diff --git a/.github/workflows/buildAndTestRyzenAI.yml b/.github/workflows/buildAndTestRyzenAI.yml index 1c42cf807..40735988b 100644 --- a/.github/workflows/buildAndTestRyzenAI.yml +++ b/.github/workflows/buildAndTestRyzenAI.yml @@ -137,7 +137,12 @@ jobs: # ninja check-air-e2e-chess # Programming examples set 1: peano tests (retry once on failure for flaky NPU tests) - ninja check-programming-examples-peano || ninja check-programming-examples-peano + # HF_TOKEN exposes the repository secret for tests requiring gated + # Hugging Face model downloads (e.g. llama32_1b/run_npu2_verify.lit). + # Tests without REQUIRES: hf_token are unaffected. + HF_TOKEN="${{ secrets.HF_TOKEN }}" \ + ninja check-programming-examples-peano || \ + HF_TOKEN="${{ secrets.HF_TOKEN }}" ninja check-programming-examples-peano # Chess tests disabled to reduce CI time. Uncomment to re-enable: # ninja check-programming-examples-chess diff --git a/programming_examples/lit.cfg.py b/programming_examples/lit.cfg.py index 7a7f86ec9..29f0a4d3f 100644 --- a/programming_examples/lit.cfg.py +++ b/programming_examples/lit.cfg.py @@ -124,6 +124,16 @@ config.substitutions.append(("%xrt_flags", xrt_flags)) config.substitutions.append(("%XRT_DIR", config.xrt_dir)) +# Tests that download Hugging Face Hub gated models (e.g. meta-llama/*) need +# HF_TOKEN to be set. Mark `hf_token` as available only when the env var is +# present so REQUIRES: hf_token tests skip cleanly on machines without it. +if os.environ.get("HF_TOKEN"): + config.available_features.add("hf_token") + llvm_config.with_environment("HF_TOKEN", os.environ["HF_TOKEN"]) + print("HF_TOKEN found in environment; hf_token feature enabled.") +else: + print("HF_TOKEN not set; hf_token feature disabled.") + llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"]) llvm_config.use_default_substitutions() diff --git a/programming_examples/llama32_1b/.gitignore b/programming_examples/llama32_1b/.gitignore index 8234f3a99..b9c52fc77 100644 --- a/programming_examples/llama32_1b/.gitignore +++ b/programming_examples/llama32_1b/.gitignore @@ -6,6 +6,16 @@ __pycache__/ kernel_cache/ air_project/ .debug/ +.pytest_cache/ + +# Stray artifacts from running scripts outside build_*/ (xrt.py + external_kernels.py +# write these to CWD by design — `make compile/run/verify` cd into BUILD_DIR first, +# but ad-hoc `python3 verify/verify_runner.py` from this dir will leak them here). +air.mlir +air.elf +air.xclbin +air.insts.bin +*.o # Local-only experimental and ad-hoc test directories test_swiglu/ @@ -18,4 +28,4 @@ flash_attn_issue/ docs/development_progress/ docs/report/ docs/issues/ -test/ +test_hf_model/ diff --git a/programming_examples/llama32_1b/Makefile b/programming_examples/llama32_1b/Makefile index 65ca843ec..a60267c24 100644 --- a/programming_examples/llama32_1b/Makefile +++ b/programming_examples/llama32_1b/Makefile @@ -26,16 +26,7 @@ N_TOKENS ?= 1000 PROMPT ?= What is the capital of France? MODEL ?= instruct -# WEIGHTS=hf (default) — load real Meta weights from HuggingFace -# WEIGHTS=synthetic — deterministic random weights (no HF, for CI) -WEIGHTS ?= hf -ifeq ($(WEIGHTS),synthetic) - WEIGHTS_FLAG := --synthetic-weights -else - WEIGHTS_FLAG := -endif - -.PHONY: help compile run profile verify chat clean +.PHONY: help compile run profile chat verify diagnosis clean # ============================================================ # Help @@ -53,21 +44,23 @@ help: @echo " make profile Run with profiling breakdown" @echo "" @echo "More targets:" - @echo " make verify With CPU reference verification" + @echo " make verify Top-k token-level inclusion gate vs HF bf16 (8 prompts × 32 tokens, k=5)" + @echo " make diagnosis Per-layer ffn_out cosine + max_abs vs HF bf16 (single prompt, informational)" @echo "" @echo "Maintenance:" - @echo " make clean Remove all build artifacts" + @echo " make clean Remove all build artifacts and verify reports" @echo "" @echo "Options (override with make VAR=value):" - @echo " N_TOKENS=1000 Max decode tokens (instruct model stops early on EOT)" - @echo " PROMPT=\"...\" Input prompt text" + @echo " N_TOKENS=1000 Max decode tokens for run/profile/chat (instruct stops early on EOT)" + @echo " PROMPT=\"...\" Input prompt text (run/profile/diagnosis)" @echo " MODEL=base|instruct Model variant (default: instruct)" @echo "" @echo "Examples:" @echo " make run N_TOKENS=50" @echo " make run MODEL=base PROMPT=\"The capital of France is\" N_TOKENS=200" @echo " make profile PROMPT=\"How does photosynthesis work?\"" - @echo " make verify N_TOKENS=10" + @echo " make verify MODEL=base" + @echo " make diagnosis PROMPT=\"The capital of France is\"" # ============================================================ # Unified Pipeline (NPU prefill + NPU decode) @@ -81,31 +74,39 @@ compile: ## Run unified inference run: cd $(BUILD_DIR) && python3 $(srcdir)/llama32_1b_inference.py \ - --run-only --n-tokens $(N_TOKENS) --prompt "$(PROMPT)" --model $(MODEL) $(WEIGHTS_FLAG) + --run-only --n-tokens $(N_TOKENS) --prompt "$(PROMPT)" --model $(MODEL) ## Run with detailed profiling breakdown profile: cd $(BUILD_DIR) && python3 $(srcdir)/llama32_1b_inference.py \ - --run-only --n-tokens $(N_TOKENS) --profile --prompt "$(PROMPT)" --model $(MODEL) $(WEIGHTS_FLAG) - -## Run with CPU reference verification -verify: - cd $(BUILD_DIR) && python3 $(srcdir)/llama32_1b_inference.py \ - --run-only --n-tokens $(N_TOKENS) --verify --profile --prompt "$(PROMPT)" --model $(MODEL) $(WEIGHTS_FLAG) + --run-only --n-tokens $(N_TOKENS) --profile --prompt "$(PROMPT)" --model $(MODEL) ## Interactive chat: prepare runtime once, then loop on prompts chat: cd $(BUILD_DIR) && python3 $(srcdir)/llama32_1b_inference.py \ - --run-only --interactive --n-tokens $(N_TOKENS) --model $(MODEL) $(WEIGHTS_FLAG) + --run-only --interactive --n-tokens $(N_TOKENS) --model $(MODEL) ## Compile and run in one step all: compile profile +## Run the top-k token-level inclusion gate (NPU vs HF bf16, 8 prompts × 32 tokens, k=5) +verify: + @mkdir -p $(BUILD_DIR) + cd $(BUILD_DIR) && python3 $(srcdir)/verify/verify_runner.py \ + --prompts topk_token --model $(MODEL) + +## Run the diagnosis lens (per-layer ffn_out cosine vs HF bf16, single prompt, informational) +diagnosis: + @mkdir -p $(BUILD_DIR) + cd $(BUILD_DIR) && python3 $(srcdir)/verify/verify_runner.py \ + --prompts single --prompt "$(PROMPT)" --model $(MODEL) + # ============================================================ # Clean # ============================================================ -## Remove all build artifacts +## Remove all build artifacts and verify reports clean: rm -r $(BUILD_DIR) 2>/dev/null || true - @echo "Build directory removed. Run 'make compile' to rebuild." + rm -rf $(srcdir)/verify/reports + @echo "Build directory and verify/reports/ removed. Run 'make compile' to rebuild." diff --git a/programming_examples/llama32_1b/README.md b/programming_examples/llama32_1b/README.md index 7f1a4d81d..61fb6e541 100644 --- a/programming_examples/llama32_1b/README.md +++ b/programming_examples/llama32_1b/README.md @@ -6,8 +6,8 @@ End-to-end LLAMA-3.2-1B (1B parameter, BF16) inference running on AMD NPU2 (AIE2 | Phase | Time | vs IRON | |-------|------|---------| -| Prefill (2048 tokens) | 1.27s wall | **2.17x faster** | -| Decode | 92ms/token (10.8 tok/s) | **4.0x faster** | +| Prefill / TTFT (2048 tokens) | 1.27s wall | **2.17x faster** | +| Decode / TPOT (steady-state) | 92ms/token (10.8 tok/s) | **4.0x faster** | ## Prerequisites @@ -51,7 +51,8 @@ make run MODEL=base PROMPT="In 1969, the first man to walk on" N_TOKENS=200 # Run with profiling breakdown make profile -# Run with correctness verification +# Run the top-k token-level correctness gate (NPU vs HF transformers bf16, +# 8 prompts × 32 greedy tokens, k=5; ~4 min). See docs/VERIFICATION.html. make verify ``` @@ -61,8 +62,12 @@ make verify |-----|-------------| | [Architecture](ARCHITECTURE.md) | Per-layer kernel sequence, runtime flow, key design patterns | | [Usage Guide](docs/usage.md) | All `make` targets, command-line options, file structure | -| [Performance Profile](docs/profile.md) | Kernel timing breakdown, BO categories, memory model | -| [Implementation Guide](docs/explain.md) | How kernels are built, compiled, and stitched together | +| [Implementation Guide](docs/IMPLEMENTATION_GUIDE.html) | Long-form production codebase walkthrough: model math (Part A), NPU mapping (Part B), verification (Part C), future work (Part D) | +| [Verification](docs/VERIFICATION.html) | `make verify` (top-k token gate) + `make diagnosis` (per-layer cosine) — design, gates, reproduction | +| [Ablation Study](docs/ABLATION_STUDY.html) | 4-cell dispatch ablation quantifying each optimization's contribution (decode 2.83×, prefill 1.56×) | +| [Performance Profile (textual)](docs/profile.md) | Kernel timing breakdown, BO categories, memory model | +| [Performance Profile (visualization)](docs/PROFILE.html) | End-to-end dataflow diagram with per-step measured timing; BO Write / NPU Run / BO Read concept walkthrough | +| [Kernel Walkthrough](docs/explain.md) | How individual kernels are built, compiled, and stitched together | | [Known Issues](docs/issues.md) | BF16 precision, fixed seq_len, no sampling | ## Key Files @@ -73,7 +78,7 @@ make verify | `llama32_1b_prefill.py` | Standalone prefill (with profiler report) | | `llama32_1b_decode.py` | Standalone decode | | `llama32_1b_weights.py` | Weight loading from HuggingFace safetensors | -| `llama32_1b_reference.py` | CPU F32 reference implementation | +| `llama32_1b_cpu_helpers.py` | NumPy helpers shared by production + verify: `rms_norm` (LM-head GEMV final norm), `attention_reference` (prefill `cpu_attn=True` fallback), `softmax` (used by `attention_reference`). | | `kernel_builder/` | Shared utilities: MLIR stitching, kernel cache, external kernel compilation | | `multi_launch_builder/` | Multi-launch ELF builders (one per fused kernel) | -| `Makefile` | Build/run/profile/verify targets | +| `Makefile` | Build / run / profile / chat / verify / diagnosis targets | diff --git a/programming_examples/llama32_1b/ablation/.gitignore b/programming_examples/llama32_1b/ablation/.gitignore new file mode 100644 index 000000000..edadeea50 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/.gitignore @@ -0,0 +1,2 @@ +build/ +standalone_cache/ diff --git a/programming_examples/llama32_1b/ablation/README.md b/programming_examples/llama32_1b/ablation/README.md new file mode 100644 index 000000000..90c5e8164 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/README.md @@ -0,0 +1,35 @@ +# Llama-3.2-1B NPU2 Ablation Study + +4-cell controlled measurement of how each dispatch optimization (multi-launch +ELF stitching, per-layer weight BOs, shared intermediate BOs) contributes to +the production runtime. + +Two sister studies: + +| Subdir | Scope | Cell D headline | +|---|---|---| +| [`decode/`](decode/) | Full per-token loop: 16 × (rms_gemv_rope + decode_attention_cpu + o_gemv_ffn) + LM head + argmax | 90.65 ms/token; A→D = **2.83×** | +| [`prefill/`](prefill/) | Full 16-layer prefill: 16 × (rms_gemms_rope + FA + o_ffn) | 1.13 s/pass; A→D = **1.56×** | + +Both studies use the same 4-cell ladder (A naive → B + per-layer weight BOs +→ C + shared intermediate BOs → D production-merged), bit-exact validation +against committed Cell D goldens, and the NPU exclusive-lock timing +protocol. + +**Audience-facing walkthrough**: [`../docs/ABLATION_STUDY.html`](../docs/ABLATION_STUDY.html) +— headline numbers, methodology, cross-comparison. + +**Reproducibility** (each subdir is self-contained): + +```sh +cd decode/ && make all # ~10 min, NPU-locked +cd prefill/ && make all # ~15 min, NPU-locked +``` + +## Companion docs (in repo) + +- [`../docs/IMPLEMENTATION_GUIDE.html`](../docs/IMPLEMENTATION_GUIDE.html) — production codebase walkthrough; B3-B7 describes the four gaps that the cells ablate +- [`../docs/profile.md`](../docs/profile.md) — production runtime numbers reproduced by Cell D +- `docs/specs/2026-05-07-llama32-1b-ablation-plan2-prefill-design.md` — prefill spec +- `docs/specs/2026-05-12-llama32-1b-ablation-plan2-fulldecode-design.md` — decode spec +- `docs/plans/...` — corresponding step-by-step implementation plans diff --git a/programming_examples/llama32_1b/ablation/decode/.gitignore b/programming_examples/llama32_1b/ablation/decode/.gitignore new file mode 100644 index 000000000..2c9a7ca66 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/decode/.gitignore @@ -0,0 +1,15 @@ +# Build / kernel cache artifacts +build/ +air_project/ +__pycache__/ +*.pyc + +# Compiled NPU kernel objects (generated by Peano during make compile) +*.o +*.elf +*.mlir +*.insts.bin + +# Run artifacts (regenerated each `make run`) +results_*.json +report_*.md diff --git a/programming_examples/llama32_1b/ablation/decode/Makefile b/programming_examples/llama32_1b/ablation/decode/Makefile new file mode 100644 index 000000000..1d58f8fb2 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/decode/Makefile @@ -0,0 +1,38 @@ +# Llama-3.2-1B Plan 2 (full decode) ablation harness +# +# make compile — compile all 4 cells' ELFs + LM head (~5-10 min, cached) +# make regen-golden — regenerate committed golden fixtures (rare; only after Cell D changes) +# make run — run all 4 cells, 5 trials each, emit JSON +# make report — generate markdown report from latest results JSON +# make test — NPU-free unit tests (kv_cache + validation gate) +# make all — compile + run + report +# make clean — wipe build/ + +srcdir := $(shell dirname $(realpath $(firstword $(MAKEFILE_LIST)))) +BUILD := build + +.PHONY: help compile regen-golden run report test all clean + +help: + @echo "make compile | regen-golden | run | report | test | all | clean" + +compile: + @mkdir -p $(BUILD) + cd $(BUILD) && PYTHONPATH=$(srcdir):$(srcdir)/..:$(srcdir)/../..:$(srcdir)/../..:$(srcdir)/../prefill:$(srcdir)/../../..:$$PYTHONPATH flock -x -w 1800 /tmp/mlir-air-npu.lock python3 -c "from cells.cell_d_merged import compile_cell_d; from cells.lm_head_const import compile_lm_head; from kernel_builder.cache import KernelCache; from golden.regen_golden import CONFIG; c = KernelCache(cache_dir='.', verbose=True); c.load_manifest(); compile_cell_d(c, CONFIG); compile_lm_head(c, CONFIG)" + +regen-golden: compile + cd $(BUILD) && PYTHONPATH=$(srcdir):$(srcdir)/..:$(srcdir)/../..:$(srcdir)/../prefill:$(srcdir)/../../..:$$PYTHONPATH flock -x -w 1800 /tmp/mlir-air-npu.lock python3 $(srcdir)/golden/regen_golden.py + +run: compile + cd $(BUILD) && PYTHONPATH=$(srcdir):$(srcdir)/..:$(srcdir)/../..:$(srcdir)/../prefill:$(srcdir)/../../..:$$PYTHONPATH flock -x -w 1800 /tmp/mlir-air-npu.lock python3 $(srcdir)/run_ablation.py --out results_latest.json + +report: + cd $(BUILD) && python3 $(srcdir)/analyze.py results_latest.json > report_latest.md && cat report_latest.md + +test: + cd $(srcdir) && python3 -m pytest tests/ -v + +all: compile run report + +clean: + rm -rf $(BUILD) diff --git a/programming_examples/llama32_1b/ablation/decode/README.md b/programming_examples/llama32_1b/ablation/decode/README.md new file mode 100644 index 000000000..b5648a131 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/decode/README.md @@ -0,0 +1,97 @@ +# Llama-3.2-1B Plan 2 (Full Decode) Ablation + +Bit-exact 4-cell ablation of the production **decode** pipeline: +`rms_gemv_rope` (6 sub-launches) + `decode_attention_cpu` (invariant) + +`o_gemv_ffn` (8 sub-launches) per layer × 16 layers + final RMSNorm + +`lm_head_gemv` (invariant) + argmax. + +Per-trial timed unit: **one decode token** at fixed `current_pos = 7` +(after a 7-token synthetic pre-fill of the KV cache). 5 trials, drop trial 1 +as warmup, median + (min, max) over remaining 4. + +Companion docs: +- Spec: [`../docs/specs/2026-05-12-llama32-1b-ablation-plan2-fulldecode-design.md`](../docs/specs/2026-05-12-llama32-1b-ablation-plan2-fulldecode-design.md) +- Plan: [`../docs/plans/2026-05-12-llama32-1b-ablation-plan2-fulldecode-plan.md`](../docs/plans/2026-05-12-llama32-1b-ablation-plan2-fulldecode-plan.md) +- Sister study (prefill): [`../prefill/README.md`](../prefill/README.md) +- Audience-facing summary: [`../../docs/ABLATION_STUDY.html`](../../docs/ABLATION_STUDY.html) + +## What this measures + +Four cells, identical computation, different dispatch strategy. CPU attention +and LM head are held INVARIANT across all 4 cells. + +| Cell | What changes within each kernel-group | Adds | +|------|---------------------------------------|------| +| A | 6+8 separate `xrt.run()` per layer, host round-trip on every intermediate | (baseline) | +| B | + per-layer weight BOs (`static_input_indices`) | #2 | +| C | + shared intermediate BOs across separate `xrt.run()` calls (within each group) | #3 | +| D | + multi-launch merging (production: 6→1 + 8→1 ELF per layer) | #1 | + +NPU calls per token (16 layers + LM head): +- Cell A/B/C: **(6 + 8) × 16 + 1 = 225 dispatches** (LM head invariant-merged) +- Cell D: **(1 + 1) × 16 + 1 = 33 dispatches** + +## Quick start + +``` +make compile # one-time, ~5-10 min for all 4 cells' ELFs + LM head +make run # 4 cells × 5 trials (~2-3 min, NPU-locked) +make report # markdown report +``` + +## Validation gate + +Every cell must produce **bit-identical** output bytes vs. committed Cell D +goldens for both kernel-groups (`golden_rms_gemv_rope_decode.npz`, +`golden_o_gemv_ffn_decode.npz`). Cells failing the gate suppress their timing. + +## Reproducibility + +``` +cd programming_examples/llama32_1b/ablation/decode +make clean +make all +``` + +NPU-free unit tests (smoke test the harness scaffolding): + +``` +make test +``` + +Expected: **8 passed** (4 KV-cache state tests + 4 validation-gate tests). + +## File map + +| Path | Purpose | +|------|---------| +| `specs/kernel_group.py` | Re-export prefill study's frozen dataclasses | +| `specs/rms_gemv_rope.py` | Concrete spec for the 6-launch decode attention pre-block | +| `specs/o_gemv_ffn.py` | Concrete spec for the 8-launch decode FFN block | +| `standalone_builders/rms_gemv_rope.py` | 6 single-launch builders + STANDALONES registry | +| `standalone_builders/o_gemv_ffn.py` | 8-element STANDALONES registry derived from spec | +| `cells/kernel_group.py` (re-export) + `cells/common.py` (re-export) | Shared infrastructure | +| `cells/cell_a_naive.py` | Cell A — copy of Plan 1 with decode-spec branches added | +| `cells/cell_b_static.py` | Cell B — same | +| `cells/cell_c_charitable.py` | Cell C — same | +| `cells/cell_d_merged.py` | Cell D — production-merged decode dispatches | +| `cells/decode_attn_const.py` | Invariant CPU attention runner | +| `cells/lm_head_const.py` | Invariant 8-partition LM head runner | +| `cells/per_token_loop.py` | The end-to-end timed unit | +| `cells/kv_cache.py` | Deterministic KV-cache init + per-trial reset | +| `golden/regen_golden.py` | Cell-D one-shot to regenerate goldens | +| `golden/golden_*.npz` | Two committed bf16 goldens + meta json | +| `validate.py` | Bit-exact gate (re-export of Plan 1's parameterized validator) | +| `run_ablation.py` | Orchestrator — compile, preload, validate, time × 4 cells | +| `analyze.py` | JSON → markdown report | +| `Makefile` | Convenience targets | +| `tests/` | NPU-free unit tests | + +## Limitations + +- Single token at fixed position. By design (see spec §5): keeps `decode_attention_cpu` + CPU work constant across trials, isolates dispatch overhead. Position-dependent + multi-token decode is out of scope. +- Synthetic seed=42 weights only. No HuggingFace. +- LM head held INVARIANT across cells. A potential follow-up could ablate it. +- NPU FlashAttention decode path NOT measured. Production uses CPU attention at head_dim=64. diff --git a/programming_examples/llama32_1b/ablation/decode/__init__.py b/programming_examples/llama32_1b/ablation/decode/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/programming_examples/llama32_1b/ablation/decode/analyze.py b/programming_examples/llama32_1b/ablation/decode/analyze.py new file mode 100644 index 000000000..d8154b5f2 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/decode/analyze.py @@ -0,0 +1,117 @@ +"""Generate a markdown report from a Plan 2 results.json. + +Usage: + python3 analyze.py results.json > report.md +""" + +import json +import sys + + +def fmt_ms(s): + return f"{s * 1000:.2f} ms" + + +def main(): + if len(sys.argv) < 2: + print("usage: python3 analyze.py results.json", file=sys.stderr) + sys.exit(1) + + with open(sys.argv[1]) as f: + r = json.load(f) + + print("# Plan 2 (full decode) ablation report") + print() + print( + f"- current_pos: **{r['current_pos']}** (after a {r['prompt_len']}-token prefill)" + ) + print( + f"- trials per cell: **{r['trials']}** (drop trial 1 as warmup, median of remaining)" + ) + print(f"- per timed trial: ONE decode token through 16 layers + LM head + argmax") + print() + + cells = r["cells"] + cell_order = ["A", "B", "C", "D"] + cell_labels = { + "A": "Naive no-merge", + "B": "+ per-layer weight BOs (#2)", + "C": "+ shared intermediate BOs (#3)", + "D": "+ multi-launch merging (#1) [production]", + } + + print("## Per-token total wall time") + print() + print("| Cell | Median | Range | Δ vs prev | Speedup vs prev |") + print("|------|--------|-------|-----------|-----------------|") + + prev_median = None + baseline = None + for c in cell_order: + if c not in cells: + continue + d = cells[c] + if "median_total_s" not in d: + print(f"| {c} {cell_labels[c]} | — | VALIDATION FAIL | — | — |") + continue + med = d["median_total_s"] + rng = f"[{fmt_ms(d['min_total_s'])}, {fmt_ms(d['max_total_s'])}]" + if prev_median is None: + delta = "—" + speed = "(baseline)" + baseline = med + else: + delta = f"{(prev_median - med) * 1000:+.2f} ms" + speed = f"{prev_median / med:.2f}×" if med > 0 else "—" + print( + f"| **{c}** {cell_labels[c]} | {fmt_ms(med)} | {rng} | {delta} | {speed} |" + ) + prev_median = med + + if baseline is not None and "D" in cells and "median_total_s" in cells["D"]: + a_to_d = baseline / cells["D"]["median_total_s"] + print() + print(f"**A → D total speedup: {a_to_d:.2f}×**") + print() + + print("## Per-kernel-group medians (single call)") + print() + print("| Cell | rms_gemv_rope median | o_gemv_ffn median |") + print("|------|----------------------|-------------------|") + for c in cell_order: + if c not in cells or "rms_gemv_rope_per_call_median_s" not in cells[c]: + continue + d = cells[c] + print( + f"| {c} | {fmt_ms(d['rms_gemv_rope_per_call_median_s'])} " + f"| {fmt_ms(d['o_gemv_ffn_per_call_median_s'])} |" + ) + print() + + print("## Component breakdown (Cell D, fixed costs)") + print() + if "D" in cells and "cpu_attn_total_median_s" in cells["D"]: + d = cells["D"] + print( + f"- CPU attention floor (sum across 16 layers): **{fmt_ms(d['cpu_attn_total_median_s'])}**" + ) + print( + f"- LM head (production-merged, invariant): **{fmt_ms(d['lm_head_median_s'])}**" + ) + print(f"- Total per-token wall: **{fmt_ms(d['median_total_s'])}**") + print() + + print("## Validation") + print() + print("| Cell | Validation |") + print("|------|------------|") + for c in cell_order: + if c not in cells: + print(f"| {c} | (not run) |") + continue + v = cells[c].get("validation", "?") + print(f"| {c} | {v} |") + + +if __name__ == "__main__": + main() diff --git a/programming_examples/llama32_1b/ablation/decode/cells/__init__.py b/programming_examples/llama32_1b/ablation/decode/cells/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/programming_examples/llama32_1b/ablation/decode/cells/cell_a_naive.py b/programming_examples/llama32_1b/ablation/decode/cells/cell_a_naive.py new file mode 100644 index 000000000..0b090e122 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/decode/cells/cell_a_naive.py @@ -0,0 +1,320 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + +"""Cell A -- Naive no-merge for a generic KernelGroupSpec. + +Walks spec.sub_launches in order. For each sub-launch: + 1. Build the 3-element args list per the spec's slot semantics. + 2. Invoke cache.load_and_run with naive=True (writes everything, + reads everything every call). + 3. Store output in results dict keyed by sub.name. + +Cross-sub-launch data flows via the host (extracted to numpy in a results +dict, then passed to the next call as input). + +naive=True forces load_and_run to: + - set output_indices = list(range(len(inputs))) (read back all slots) + - skip static_input_indices and intermediate_indices optimizations + +The returned result[slot] is always a 1D flat numpy array. Baton-link values +are passed directly as inputs to downstream sub-launches; the BO write uses +raw bytes so 1D vs 2D shape does not matter as long as byte counts match. +""" + +import time + +import numpy as np +from ml_dtypes import bfloat16 + +from cells.common import compile_standalone_kernels + + +def _output_shape_for(spec_name, sub_name, config): + """Return numpy shape of the output buffer for (spec_name, sub_name). + + The output buffer is allocated as zeros with this shape and passed at + sub.output_slot_in_standalone. The kernel writes into it; load_and_run + returns a 1D flat view (byte-compatible with the 2D shape). + """ + seq = config["seq_len"] + emb = config["emb_dim"] + kv = config["kv_dim"] + hid = config["hidden_dim"] + n_total = seq * emb + + if spec_name == "rms_gemms_rope": + return { + "rmsnorm": (seq, emb), + "q_gemm": (seq, emb), + "k_gemm": (seq, kv), + "v_gemm": (seq, kv), + "rope_q": (seq, emb), + "rope_k": (seq, kv), + }[sub_name] + + if spec_name == "o_ffn": + return { + "o_gemm": (seq, emb), + "res_add": (seq, emb), + "ffn_rmsnorm": (seq, emb), + "gate_gemm": (seq, hid), + "up_gemm": (seq, hid), + "swiglu": (seq, hid), + "down_gemm": (seq, emb), + "ffn_add": (n_total,), # 1D output (standalone emits 1D; see o_ffn.py) + }[sub_name] + + # ---- Decode (single-token, 1D outputs) ---- + if spec_name == "rms_gemv_rope": + return { + "rmsnorm": (emb,), + "q_gemv": (emb,), + "k_gemv": (kv,), + "v_gemv": (kv,), + "rope_q": (emb,), # n_heads * head_dim = 32*64 = emb + "rope_k": (kv,), # n_kv_heads * head_dim = 8*64 = kv + }[sub_name] + + if spec_name == "o_gemv_ffn": + return { + "o_gemv": (emb,), + "add_attn_residual": (emb,), + "ffn_rmsnorm": (emb,), + "gate_gemv": (hid,), + "up_gemv": (hid,), + "swiglu": (hid,), + "down_gemv_k8192": (emb,), + "add_ffn_residual": (emb,), + }[sub_name] + + raise ValueError(f"unknown spec {spec_name!r}") + + +def _static_input_for(spec_name, sub_name, slot, layer_inputs): + """Return the static (weight/LUT/layer-level) array for this slot, or None. + + Returns None when the slot should come from a baton link (upstream + sub-launch output) or from the output buffer. + """ + if spec_name == "rms_gemms_rope": + # Slot conventions (from rms_gemms_rope.py docstring): + # rmsnorm: (x_in[slot0], norm_w[slot1], out[slot2]) + # gemm: (A[slot0], B_weight[slot1], C[slot2]) + # rope_2d: (in[slot0], lut[slot1], out[slot2]) + if sub_name == "rmsnorm": + if slot == 0: + return layer_inputs["x_in"] + if slot == 1: + return layer_inputs["norm_w"] + elif sub_name == "q_gemm": + if slot == 1: + return layer_inputs["wq"] + # slot 0 comes from rmsnorm baton + elif sub_name == "k_gemm": + if slot == 1: + return layer_inputs["wk"] + # slot 0 comes from rmsnorm baton + elif sub_name == "v_gemm": + if slot == 1: + return layer_inputs["wv"] + # slot 0 comes from rmsnorm baton + elif sub_name == "rope_q": + if slot == 1: + return layer_inputs["lut_q"] + # slot 0 comes from q_gemm baton + elif sub_name == "rope_k": + if slot == 1: + return layer_inputs["lut_k"] + # slot 0 comes from k_gemm baton + return None + + if spec_name == "o_ffn": + # Slot conventions (from o_ffn.py docstring): + # gemm: (A[slot0], B_weight[slot1], C[slot2]) + # add_2d_to_2d: (A[slot0], B[slot1], C[slot2]) no weight + # rmsnorm: (x[slot0], w[slot1], out[slot2]) + # swiglu_2d: (gate[slot0], up[slot1], out[slot2]) no weight + # ffn_add: (A[slot0], B[slot1], out[slot2]) no weight + if sub_name == "o_gemm": + if slot == 0: + return layer_inputs["attn_out"] + if slot == 1: + return layer_inputs["wo"] + elif sub_name == "res_add": + # slot0 = proj (from o_gemm baton); slot1 = x_residual (static) + if slot == 1: + return layer_inputs["x_residual"] + elif sub_name == "ffn_rmsnorm": + if slot == 1: + return layer_inputs["ffn_norm_w"] + # slot 0 comes from res_add baton + elif sub_name == "gate_gemm": + if slot == 1: + return layer_inputs["w_gate"] + # slot 0 comes from ffn_rmsnorm baton + elif sub_name == "up_gemm": + if slot == 1: + return layer_inputs["w_up"] + # slot 0 comes from ffn_rmsnorm baton + elif sub_name == "swiglu": + # both slot0 (gate) and slot1 (up) come from batons + pass + elif sub_name == "down_gemm": + if slot == 1: + return layer_inputs["w_down"] + # slot 0 comes from swiglu baton + elif sub_name == "ffn_add": + # slot0 = down (from down_gemm baton); slot1 = res1 (from res_add baton) + pass + return None + + # ---- Decode kernel-groups ---- + # CRITICAL: GEMV slot convention differs from prefill GEMM! + # gemv: (W_weight[slot0], x[slot1], y[slot2]) ← W is at slot 0, NOT slot 1 + if spec_name == "rms_gemv_rope": + # Slot conventions for decode rms_gemv_rope sub-launches: + # rmsnorm: (x_in[slot0], norm_w[slot1], out[slot2]) + # gemv: (W[slot0], x[slot1], y[slot2]) + # rope: (in[slot0], lut[slot1], out[slot2]) + if sub_name == "rmsnorm": + if slot == 0: + return layer_inputs["x_in"] + if slot == 1: + return layer_inputs["norm_w"] + elif sub_name == "q_gemv": + if slot == 0: + return layer_inputs["wq"] + # slot 1 (x = normed) comes from rmsnorm baton + elif sub_name == "k_gemv": + if slot == 0: + return layer_inputs["wk"] + elif sub_name == "v_gemv": + if slot == 0: + return layer_inputs["wv"] + elif sub_name == "rope_q": + if slot == 1: + return layer_inputs["lut_q"] + # slot 0 (in = q) comes from q_gemv baton + elif sub_name == "rope_k": + if slot == 1: + return layer_inputs["lut_k"] + return None + + if spec_name == "o_gemv_ffn": + # Slot conventions for decode o_gemv_ffn sub-launches: + # gemv: (W[slot0], x[slot1], y[slot2]) + # add: (A[slot0], B[slot1], out[slot2]) no weight + # rmsnorm: (x[slot0], w[slot1], out[slot2]) + # swiglu: (gate[slot0], up[slot1], out[slot2]) no weight + if sub_name == "o_gemv": + if slot == 0: + return layer_inputs["wo"] + if slot == 1: + return layer_inputs["attn_out"] + elif sub_name == "add_attn_residual": + # slot 0 = proj (from o_gemv baton); slot 1 = x_residual + if slot == 1: + return layer_inputs["x_residual"] + elif sub_name == "ffn_rmsnorm": + if slot == 1: + return layer_inputs["ffn_norm_w"] + # slot 0 (x = res1) comes from add_attn_residual baton + elif sub_name == "gate_gemv": + if slot == 0: + return layer_inputs["w_gate"] + # slot 1 (x = normed2) comes from ffn_rmsnorm baton + elif sub_name == "up_gemv": + if slot == 0: + return layer_inputs["w_up"] + elif sub_name == "swiglu": + # both slot 0 (gate) and slot 1 (up) come from batons + pass + elif sub_name == "down_gemv_k8192": + if slot == 0: + return layer_inputs["w_down"] + # slot 1 (x = swiglu) comes from swiglu baton + elif sub_name == "add_ffn_residual": + # slot 0 = down (from down_gemv baton); slot 1 = res1 (from add_attn baton) + pass + return None + + raise ValueError(f"unknown spec {spec_name!r}") + + +def compile_cell_a(cache, spec, backend_preset): + """Compile the standalone ELFs for this kernel-group into cache.""" + registry = [(s.name, s.builder_ref, s.build_kwargs) for s in spec.sub_launches] + compile_standalone_kernels(cache, spec.name, registry, backend_preset) + + +def run_cell_a(cache, spec, layer_inputs, config, backend_preset, layer_idx=0): + """Run all spec.sub_launches sequentially with naive=True. + + Each sub-launch is a separate xrt.run() call. All host<->device transfers + are done unconditionally (naive=True means no skipping of static or + intermediate buffers). + + Args: + cache: KernelCache with manifested artifacts. + spec: KernelGroupSpec (rms_gemms_rope or o_ffn). + layer_inputs: dict of numpy arrays keyed by semantic name + (e.g. "x_in", "norm_w", "wq", "attn_out", etc.). + config: dict with seq_len, emb_dim, kv_dim, hidden_dim. + backend_preset: backend kwargs dict (instance_name will be removed). + layer_idx: layer index (unused in Cell A, present for API consistency). + + Returns: + dict keyed by sub.name -> 1D flat numpy array of that sub-launch's + output, plus "_wall_s" for total wall time. + """ + # Strip instance_name; compile_cell_a sets it per-kernel. + backend = {**backend_preset} + backend.pop("instance_name", None) + + results = {} + t0 = time.perf_counter() + + for idx, sub in enumerate(spec.sub_launches): + out_shape = _output_shape_for(spec.name, sub.name, config) + out_buf = np.zeros(out_shape, dtype=bfloat16) + + # Build the 3-arg list (all standalones have exactly 3 args). + args = [None, None, None] + + for slot in range(3): + if slot == sub.output_slot_in_standalone: + args[slot] = out_buf + continue + + # Try static (weight/layer-level) lookup first. + v = _static_input_for(spec.name, sub.name, slot, layer_inputs) + if v is not None: + args[slot] = v + continue + + # Otherwise this slot is fed by an upstream baton link. + for link in spec.baton_links: + if link.consumer_idx == idx and link.consumer_in_slot == slot: + producer_name = spec.sub_launches[link.producer_idx].name + args[slot] = results[producer_name] + break + + assert args[slot] is not None, ( + f"[cell_a] no source found for {spec.name}/{sub.name} slot={slot}. " + f"Check baton_links and _static_input_for." + ) + + kernel_name = f"{spec.name}__{sub.name}" + result = cache.load_and_run( + kernel_name, + backend, + *args, + naive=True, + ) + # naive=True sets output_indices = list(range(3)), so result is a 3-tuple. + # The output is at sub.output_slot_in_standalone. + results[sub.name] = result[sub.output_slot_in_standalone] + + elapsed = time.perf_counter() - t0 + results["_wall_s"] = elapsed + return results diff --git a/programming_examples/llama32_1b/ablation/decode/cells/cell_b_static.py b/programming_examples/llama32_1b/ablation/decode/cells/cell_b_static.py new file mode 100644 index 000000000..e4c1353e7 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/decode/cells/cell_b_static.py @@ -0,0 +1,270 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + +"""Cell B -- Cell A + per-layer weight BOs + static_input_indices. + +Same dataflow as Cell A (walks spec.sub_launches, threads via baton links), +but weights are pre-loaded once into per-layer BOs during preload phase. +The timed run phase skips the weight host->device sync via static_input_indices. + +Two public phases: + + preload_cell_b(cache, spec, weights_per_layer, config, backend_preset) + Called once before timing. For each (layer_idx, sub_launch): + - Builds a 3-arg list with the actual weight at weight_slot_in_standalone + and dummy zeros at all other slots. + - Calls load_and_run with output_indices=[output_slot], + static_input_indices={weight_slot}, and + bo_key=f"B_{spec.name}_{sub.name}_L{layer_idx}". + Sub-launches with weight_slot_in_standalone=None are skipped (no weight + to preload; those sub-launches just use default bo_key in the timed run). + + run_cell_b(cache, spec, layer_inputs, config, backend_preset, layer_idx=0) + Same loop as Cell A but: + - No naive=True. + - Passes static_input_indices={sub.weight_slot_in_standalone} (or empty + set if None) and output_indices=[sub.output_slot_in_standalone]. + - Passes bo_key=f"B_{spec.name}_{sub.name}_L{layer_idx}" -- must + byte-match the preload bo_key. + +Helpers _output_shape_for and _static_input_for are imported from cell_a_naive +to avoid duplication. +""" + +import time + +import numpy as np +from ml_dtypes import bfloat16 + +from cells.cell_a_naive import _output_shape_for, _static_input_for +from cells.common import compile_standalone_kernels + + +def _activation_shape_for(spec_name, sub_name, config): + """Return the numpy shape of the activation (non-weight, non-output) input slot. + + This is needed during preload to allocate a correctly-sized dummy BO for the + activation slot. All current standalones have exactly 3 args: + (activation, weight, output). The activation is always at slot 0. + + Shapes must match what _static_input_for / baton links would supply at + run time, because the BO is allocated on the first call (preload) and + reused on subsequent calls (run). A size mismatch raises a ValueError + inside KernelCache.load_and_run when it tries to copy src into the BO. + """ + seq = config["seq_len"] + emb = config["emb_dim"] + kv = config["kv_dim"] + hid = config["hidden_dim"] + + if spec_name == "rms_gemms_rope": + # All sub-launches: activation at slot 0 is either x_in (seq,emb) or + # the normed/q/k output fed via baton -- all are (seq, emb) or (seq, kv). + return { + # rmsnorm: x_in is (seq, emb) + "rmsnorm": (seq, emb), + # gemms: A input is (seq, emb) -- the normed activation + "q_gemm": (seq, emb), + "k_gemm": (seq, emb), + "v_gemm": (seq, emb), + # ropes: activation slot is the q/k output + "rope_q": (seq, emb), + "rope_k": (seq, kv), + }[sub_name] + + if spec_name == "o_ffn": + return { + # o_gemm: activation = attn_out (seq, emb) + "o_gemm": (seq, emb), + # ffn_rmsnorm: activation = res1 (seq, emb) + "ffn_rmsnorm": (seq, emb), + # gate/up gemms: activation = normed2 (seq, emb) + "gate_gemm": (seq, emb), + "up_gemm": (seq, emb), + # down_gemm: activation = swiglu (seq, hid) + "down_gemm": (seq, hid), + }[sub_name] + + # ---- Decode (single-token, 1D activations) ---- + if spec_name == "rms_gemv_rope": + # All activations are 1D. The activation slot is whichever non-weight, + # non-output slot exists; preload sets a dummy of this size in any + # missing slot. + return { + "rmsnorm": (emb,), # x_in at slot 0 + "q_gemv": (emb,), # x at slot 1 (input dim K=emb) + "k_gemv": (emb,), # x at slot 1 + "v_gemv": (emb,), # x at slot 1 + "rope_q": (emb,), # in at slot 0 (n_heads * head_dim = emb) + "rope_k": (kv,), # in at slot 0 (n_kv_heads * head_dim = kv) + }[sub_name] + + if spec_name == "o_gemv_ffn": + return { + "o_gemv": (emb,), # attn_out at slot 1 + "add_attn_residual": (emb,), # A & B at slots 0,1 both (emb,) + "ffn_rmsnorm": (emb,), # res1 at slot 0 + "gate_gemv": (emb,), # normed2 at slot 1 (input dim K=emb) + "up_gemv": (emb,), # normed2 at slot 1 + "swiglu": (hid,), # gate, up both (hid,) + "down_gemv_k8192": (hid,), # swiglu at slot 1 (input dim K=hid) + "add_ffn_residual": (emb,), # A & B at slots 0,1 + }[sub_name] + + raise ValueError(f"unknown spec {spec_name!r} or sub {sub_name!r}") + + +def compile_cell_b(cache, spec, backend_preset): + """Compile the standalone ELFs for this kernel-group into cache.""" + registry = [(s.name, s.builder_ref, s.build_kwargs) for s in spec.sub_launches] + compile_standalone_kernels(cache, spec.name, registry, backend_preset) + + +def preload_cell_b(cache, spec, weights_per_layer, config, backend_preset): + """Pre-load per-layer weights into dedicated BOs. + + For each (layer_idx, weights) pair and each sub-launch with a weight slot, + run a one-shot load_and_run that writes the weight into the BO. Subsequent + timed runs reuse the same BO (identified by bo_key) and skip the write. + + Args: + cache: KernelCache with manifested artifacts. + spec: KernelGroupSpec (rms_gemms_rope or o_ffn). + weights_per_layer: list of dicts (one per layer), each keyed by semantic + weight name (same keys accepted by _static_input_for / Cell A). + config: dict with seq_len, emb_dim, kv_dim, hidden_dim. + backend_preset: backend kwargs dict (instance_name will be removed). + """ + backend = {**backend_preset} + backend.pop("instance_name", None) + + for layer_idx, layer_weights in enumerate(weights_per_layer): + for sub in spec.sub_launches: + if sub.weight_slot_in_standalone is None: + # No weight slot -- nothing to preload for this sub-launch. + continue + + out_shape = _output_shape_for(spec.name, sub.name, config) + out_buf = np.zeros(out_shape, dtype=bfloat16) + + # Build the 3-arg list: weight at weight_slot, output at output_slot, + # dummy zeros at remaining slot(s). + args = [None, None, None] + weight_slot = sub.weight_slot_in_standalone + output_slot = sub.output_slot_in_standalone + args[output_slot] = out_buf + + # Retrieve the weight array using the same lookup as Cell A. + weight_arr = _static_input_for( + spec.name, sub.name, weight_slot, layer_weights + ) + assert weight_arr is not None, ( + f"[cell_b preload] _static_input_for returned None for " + f"{spec.name}/{sub.name} slot={weight_slot}. " + f"Check weight keys in weights_per_layer." + ) + args[weight_slot] = weight_arr + + # Fill any remaining slot with a correctly-sized dummy zero array. + # The BO is allocated on this first call and reused in run_cell_b; + # the size must match what the real activation will supply. + for slot in range(3): + if args[slot] is None: + act_shape = _activation_shape_for(spec.name, sub.name, config) + args[slot] = np.zeros(act_shape, dtype=bfloat16) + + bo_key = f"B_{spec.name}_{sub.name}_L{layer_idx}" + kernel_name = f"{spec.name}__{sub.name}" + + cache.load_and_run( + kernel_name, + backend, + *args, + output_indices=[output_slot], + static_input_indices={weight_slot}, + bo_key=bo_key, + ) + + +def run_cell_b(cache, spec, layer_inputs, config, backend_preset, layer_idx=0): + """Run all spec.sub_launches sequentially with pre-loaded weight BOs. + + Same dataflow as Cell A (batons via results dict) but: + - Uses static_input_indices={weight_slot} to skip weight write on this call. + - Uses output_indices=[output_slot] instead of naive read-all. + - Uses bo_key matching the preload phase so the same BO set is reused. + + Sub-launches with weight_slot_in_standalone=None (e.g. swiglu, ffn_add) + have no static weight -- they use an empty static_input_indices set and + the same bo_key pattern for BO identity. + + Args: + cache: KernelCache with manifested artifacts. + spec: KernelGroupSpec (rms_gemms_rope or o_ffn). + layer_inputs: dict of numpy arrays keyed by semantic name. + config: dict with seq_len, emb_dim, kv_dim, hidden_dim. + backend_preset: backend kwargs dict (instance_name will be removed). + layer_idx: layer index used to select the right pre-loaded BO set. + + Returns: + dict keyed by sub.name -> 1D flat numpy array of that sub-launch's + output, plus "_wall_s" for total wall time. + """ + backend = {**backend_preset} + backend.pop("instance_name", None) + + results = {} + t0 = time.perf_counter() + + for idx, sub in enumerate(spec.sub_launches): + out_shape = _output_shape_for(spec.name, sub.name, config) + out_buf = np.zeros(out_shape, dtype=bfloat16) + + # Build the 3-arg list (all standalones have exactly 3 args). + args = [None, None, None] + + for slot in range(3): + if slot == sub.output_slot_in_standalone: + args[slot] = out_buf + continue + + # Try static (weight/layer-level) lookup first. + v = _static_input_for(spec.name, sub.name, slot, layer_inputs) + if v is not None: + args[slot] = v + continue + + # Otherwise this slot is fed by an upstream baton link. + for link in spec.baton_links: + if link.consumer_idx == idx and link.consumer_in_slot == slot: + producer_name = spec.sub_launches[link.producer_idx].name + args[slot] = results[producer_name] + break + + assert args[slot] is not None, ( + f"[cell_b] no source found for {spec.name}/{sub.name} slot={slot}. " + f"Check baton_links and _static_input_for." + ) + + # Determine static_input_indices for this sub-launch. + if sub.weight_slot_in_standalone is not None: + static_indices = {sub.weight_slot_in_standalone} + else: + static_indices = set() + + kernel_name = f"{spec.name}__{sub.name}" + bo_key = f"B_{spec.name}_{sub.name}_L{layer_idx}" + + result = cache.load_and_run( + kernel_name, + backend, + *args, + output_indices=[sub.output_slot_in_standalone], + static_input_indices=static_indices, + bo_key=bo_key, + ) + results[sub.name] = result[sub.output_slot_in_standalone] + + elapsed = time.perf_counter() - t0 + results["_wall_s"] = elapsed + return results diff --git a/programming_examples/llama32_1b/ablation/decode/cells/cell_c_charitable.py b/programming_examples/llama32_1b/ablation/decode/cells/cell_c_charitable.py new file mode 100644 index 000000000..7871ab1ea --- /dev/null +++ b/programming_examples/llama32_1b/ablation/decode/cells/cell_c_charitable.py @@ -0,0 +1,308 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + +"""Cell C -- Cell B + shared intermediate BOs across separate xrt.run() calls, +parameterized over a KernelGroupSpec. Walks spec.baton_links to alias BOs. + +Two public phases: + + preload_cell_c(cache, spec, weights_per_layer, config, backend_preset) + Called once before timing. For each (layer_idx, layer_weights) pair: + 1. Run each sub-launch once (allocates BOs and writes weights via + static_input_indices). Uses bo_key=f"C_{spec.name}_{sub.name}_L{li}". + 2. Walk spec.baton_links and alias each producer's output BO into + the consumer's input BO slot via _share_bo. + + run_cell_c(cache, spec, layer_inputs, config, backend_preset, layer_idx=0) + Same dataflow as Cell B but with: + - bo_key=f"C_{spec.name}_{sub.name}_L{layer_idx}" (matches preload). + - intermediate_indices: producer output slots and consumer input slots + that are baton-managed (host skips writing those BOs). + +For a baton-aliased slot, a np.zeros placeholder is passed to load_and_run; +the bytes are NOT written to device because the slot is in intermediate_indices. +""" + +import time + +import numpy as np +from ml_dtypes import bfloat16 + +from cells.cell_a_naive import _output_shape_for, _static_input_for +from cells.common import compile_standalone_kernels, _share_bo + +# --------------------------------------------------------------------------- +# Compile (same registry walk as Cell A / Cell B) +# --------------------------------------------------------------------------- + + +def compile_cell_c(cache, spec, backend_preset): + """Compile the standalone ELFs for this kernel-group into cache.""" + registry = [(s.name, s.builder_ref, s.build_kwargs) for s in spec.sub_launches] + compile_standalone_kernels(cache, spec.name, registry, backend_preset) + + +# --------------------------------------------------------------------------- +# Shape helpers +# --------------------------------------------------------------------------- + + +def _slot_shape_for(spec_name, sub_name, slot, config): + """Return the numpy shape for an arbitrary (sub_name, slot) pair. + + Covers both weight slots and activation/baton slots so that the preload + loop can allocate correctly-sized BOs for all sub-launches, including + those with no weight slot (res_add, swiglu, ffn_add). + + For weight slots this returns the weight shape (2-D for GEMMs, 1-D for + norms/LUTs). For activation/baton slots it returns the activation shape. + """ + seq = config["seq_len"] + emb = config["emb_dim"] + kv = config["kv_dim"] + hid = config["hidden_dim"] + + if spec_name == "rms_gemms_rope": + # slot 2 = output for every sub-launch; handled by _output_shape_for. + table = { + # slot0 slot1 + "rmsnorm": [(seq, emb), (emb,)], + "q_gemm": [(seq, emb), (emb, emb)], + "k_gemm": [(seq, emb), (emb, kv)], + "v_gemm": [(seq, emb), (emb, kv)], + "rope_q": [(seq, emb), (seq * emb,)], + "rope_k": [(seq, kv), (seq * kv,)], + } + return table[sub_name][slot] + + if spec_name == "o_ffn": + table = { + # slot0 slot1 + "o_gemm": [(seq, emb), (emb, emb)], + "res_add": [(seq, emb), (seq, emb)], + "ffn_rmsnorm": [(seq, emb), (emb,)], + "gate_gemm": [(seq, emb), (emb, hid)], + "up_gemm": [(seq, emb), (emb, hid)], + "swiglu": [(seq, hid), (seq, hid)], + "down_gemm": [(seq, hid), (hid, emb)], + "ffn_add": [(seq, emb), (seq, emb)], + } + return table[sub_name][slot] + + # ---- Decode (single-token, 1D activations) ---- + # NOTE: GEMV slot convention is (W[slot0], x[slot1], y[slot2]) — W is at + # slot 0, NOT slot 1 like prefill GEMM. Tables encode actual decode shapes. + if spec_name == "rms_gemv_rope": + table = { + # slot0 slot1 + "rmsnorm": [(emb,), (emb,)], # x_in, norm_w + "q_gemv": [(emb, emb), (emb,)], # W, x (GEMV W at slot 0!) + "k_gemv": [(kv, emb), (emb,)], # W, x + "v_gemv": [(kv, emb), (emb,)], # W, x + "rope_q": [(emb,), (emb,)], # in, lut (lut is n_rows*head_dim flat) + "rope_k": [(kv,), (kv,)], # in, lut + } + return table[sub_name][slot] + + if spec_name == "o_gemv_ffn": + table = { + # slot0 slot1 + "o_gemv": [(emb, emb), (emb,)], # wo, attn_out + "add_attn_residual": [(emb,), (emb,)], # proj, x_residual + "ffn_rmsnorm": [(emb,), (emb,)], # res1, ffn_norm_w + "gate_gemv": [(hid, emb), (emb,)], # w_gate, normed2 + "up_gemv": [(hid, emb), (emb,)], # w_up, normed2 + "swiglu": [(hid,), (hid,)], # gate, up + "down_gemv_k8192": [(emb, hid), (hid,)], # w_down, swiglu + "add_ffn_residual": [(emb,), (emb,)], # down, res1 + } + return table[sub_name][slot] + + raise ValueError(f"unknown spec {spec_name!r} or sub {sub_name!r}") + + +# --------------------------------------------------------------------------- +# Baton-link helpers +# --------------------------------------------------------------------------- + + +def _intermediate_slots_for_sub(spec, sub_idx): + """For a given sub-launch index, return the set of slots that are + baton-managed (either produced or consumed via a baton link). + + These slots are passed as intermediate_indices to load_and_run so the + host skips writing them: + - Producer output slot: the kernel writes here; downstream reads from the + same BO via the alias. + - Consumer input slot: upstream already wrote to it via the shared BO; + host must not overwrite with zeros. + """ + slots = set() + for link in spec.baton_links: + if link.producer_idx == sub_idx: + slots.add(link.producer_out_slot) + if link.consumer_idx == sub_idx: + slots.add(link.consumer_in_slot) + return slots + + +# --------------------------------------------------------------------------- +# Preload phase +# --------------------------------------------------------------------------- + + +def preload_cell_c(cache, spec, weights_per_layer, config, backend_preset): + """One-shot allocation: run each sub-launch once to materialise BOs, then + alias intermediate BOs across sub-launches per spec.baton_links. + + Phase 1 (inner loop over sub_launches): Each sub-launch is invoked once + with its actual weight in place and dummy zeros for all other inputs. + This causes KernelCache to allocate the BO set for that bo_key. + + Phase 2 (inner loop over baton_links): _share_bo aliases the producer's + output BO into the consumer's input BO slot so that both operations refer + to the same xrt.bo object. + """ + backend = {**backend_preset} + backend.pop("instance_name", None) + + for li, layer_weights in enumerate(weights_per_layer): + # --- Phase 1: allocate BOs for every sub-launch --- + for sub in spec.sub_launches: + out_shape = _output_shape_for(spec.name, sub.name, config) + args = [None, None, None] + + for slot in range(3): + if slot == sub.output_slot_in_standalone: + args[slot] = np.zeros(out_shape, dtype=bfloat16) + continue + if ( + sub.weight_slot_in_standalone is not None + and slot == sub.weight_slot_in_standalone + ): + # Use the actual weight so the BO is populated from the start. + w = _static_input_for(spec.name, sub.name, slot, layer_weights) + assert w is not None, ( + f"[cell_c preload] _static_input_for returned None for " + f"{spec.name}/{sub.name} slot={slot}" + ) + args[slot] = w + continue + # Activation or baton-fed slot: correctly-sized dummy zeros. + args[slot] = np.zeros( + _slot_shape_for(spec.name, sub.name, slot, config), dtype=bfloat16 + ) + + static_idx = ( + {sub.weight_slot_in_standalone} + if sub.weight_slot_in_standalone is not None + else set() + ) + kernel_name = f"{spec.name}__{sub.name}" + bo_key = f"C_{spec.name}_{sub.name}_L{li}" + + cache.load_and_run( + kernel_name, + backend, + *args, + output_indices=[sub.output_slot_in_standalone], + static_input_indices=static_idx, + bo_key=bo_key, + ) + + # --- Phase 2: alias BOs per baton_links --- + for link in spec.baton_links: + producer = spec.sub_launches[link.producer_idx] + consumer = spec.sub_launches[link.consumer_idx] + _share_bo( + cache, + f"C_{spec.name}_{producer.name}_L{li}", + link.producer_out_slot, + f"C_{spec.name}_{consumer.name}_L{li}", + link.consumer_in_slot, + ) + + +# --------------------------------------------------------------------------- +# Timed run phase +# --------------------------------------------------------------------------- + + +def run_cell_c(cache, spec, layer_inputs, config, backend_preset, layer_idx=0): + """Run all spec.sub_launches sequentially with pre-loaded weight BOs and + shared intermediate BOs (baton-pass). + + Differences from Cell B: + - bo_key uses "C_" prefix (matches preload). + - intermediate_indices is set for each sub-launch based on baton_links: + * producer's output slot -> kernel overwrites it; don't host-write + * consumer's input slot -> aliased to upstream BO; don't host-write + + For baton-fed input slots the numpy arg is np.zeros (placeholder); bytes + are skipped because the slot is in intermediate_indices. + + Args: + cache: KernelCache with manifested artifacts (preload must have run). + spec: KernelGroupSpec (rms_gemms_rope or o_ffn). + layer_inputs: dict of numpy arrays keyed by semantic name. + config: dict with seq_len, emb_dim, kv_dim, hidden_dim. + backend_preset: backend kwargs dict (instance_name will be removed). + layer_idx: layer index used to select the right pre-loaded BO set. + + Returns: + dict keyed by sub.name -> 1D flat numpy array of that sub-launch's + output, plus "_wall_s" for total wall time. + """ + backend = {**backend_preset} + backend.pop("instance_name", None) + + results = {} + t0 = time.perf_counter() + + for idx, sub in enumerate(spec.sub_launches): + out_shape = _output_shape_for(spec.name, sub.name, config) + + # Build the 3-arg list. + args = [None, None, None] + + for slot in range(3): + if slot == sub.output_slot_in_standalone: + args[slot] = np.zeros(out_shape, dtype=bfloat16) + continue + + # Try static (weight/LUT/layer-level) lookup first. + v = _static_input_for(spec.name, sub.name, slot, layer_inputs) + if v is not None: + args[slot] = v + continue + + # Baton-fed slot: host won't write it (intermediate_indices); use + # a correctly-sized zero placeholder so the array shape is valid. + args[slot] = np.zeros( + _slot_shape_for(spec.name, sub.name, slot, config), dtype=bfloat16 + ) + + intermediate_idx = _intermediate_slots_for_sub(spec, idx) + static_idx = ( + {sub.weight_slot_in_standalone} + if sub.weight_slot_in_standalone is not None + else set() + ) + + kernel_name = f"{spec.name}__{sub.name}" + bo_key = f"C_{spec.name}_{sub.name}_L{layer_idx}" + + result = cache.load_and_run( + kernel_name, + backend, + *args, + output_indices=[sub.output_slot_in_standalone], + static_input_indices=static_idx, + intermediate_indices=intermediate_idx, + bo_key=bo_key, + ) + results[sub.name] = result[sub.output_slot_in_standalone] + + elapsed = time.perf_counter() - t0 + results["_wall_s"] = elapsed + return results diff --git a/programming_examples/llama32_1b/ablation/decode/cells/cell_d_merged.py b/programming_examples/llama32_1b/ablation/decode/cells/cell_d_merged.py new file mode 100644 index 000000000..c17af76d5 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/decode/cells/cell_d_merged.py @@ -0,0 +1,215 @@ +"""Cell D — production-merged decode ELFs. + +Compiles and invokes: +- rms_gemv_rope.elf (6 stitched launches in one xrt.run) +- o_gemv_ffn.elf (8 stitched launches in one xrt.run) + +Mirrors production llama32_1b_inference.py decode dispatch (static_input_indices ++ bo_key per layer). The lm_head_gemv ELF is compiled here too but invoked via +cells.lm_head_const (held INVARIANT across cells). + +Three public functions: +- compile_cell_d(cache, config): compile rms_gemv_rope + o_gemv_ffn ELFs. +- preload_cell_d(cache, weights_per_layer, rope_lut_pos_q, rope_lut_pos_k, config): + one-time per-layer BO + weight preload. +- run_rms_gemv_rope_d(cache, layer_inputs, layer_idx) → dict. +- run_o_gemv_ffn_d(cache, layer_inputs, layer_idx) → dict. +""" + +import time + +import numpy as np +from ml_dtypes import bfloat16 + +from kernel_builder.cache import KernelCache +from kernel_builder.backend_presets import RGR_BACKEND, OGF_BACKEND + +# Production decode static_input_indices (mirrors llama32_1b_inference.py preload): +# rms_gemv_rope: {1, 3, 5, 7} = norm_w, wq, wk, wv (LUTs at 9, 10 NOT static) +# o_gemv_ffn: {0, 5, 7, 9, 12} = wo, ffn_norm_w, w_gate, w_up, w_down +_RGR_STATIC = {1, 3, 5, 7} +_RGR_INTERMEDIATE = {2, 4, 6, 8, 11, 12} +_OGF_STATIC = {0, 5, 7, 9, 12} +_OGF_INTERMEDIATE = {2, 4, 6, 8, 10, 11, 13, 14} + + +def compile_cell_d(cache: KernelCache, config): + """Compile production rms_gemv_rope and o_gemv_ffn ELFs (one-time).""" + if "rms_gemv_rope" not in cache.artifacts: + from multi_launch_builder.rms_gemv_rope_multi import build_rms_gemv_rope_module + + mod = build_rms_gemv_rope_module( + emb_dim=config["emb_dim"], + kv_dim=config["kv_dim"], + n_heads=config["n_heads"], + n_kv_heads=config["n_kv_heads"], + head_dim=config["head_dim"], + ) + cache.compile_and_cache( + "rms_gemv_rope", + mod, + {**RGR_BACKEND, "verbose": getattr(cache, "verbose", False)}, + ) + cache._save_manifest() + + if "o_gemv_ffn" not in cache.artifacts: + from multi_launch_builder.o_gemv_ffn_multi import build_o_gemv_ffn_module + + mod = build_o_gemv_ffn_module( + emb_dim=config["emb_dim"], + hidden_dim=config["hidden_dim"], + ) + cache.compile_and_cache( + "o_gemv_ffn", + mod, + {**OGF_BACKEND, "verbose": getattr(cache, "verbose", False)}, + ) + cache._save_manifest() + + +def preload_cell_d(cache, weights_per_layer, lut_q, lut_k, config): + """Pre-load per-layer weights into per-layer BOs. + + Mirrors production llama32_1b_inference.py preload pattern. After this, + each layer's BO set holds its weights resident on the NPU; subsequent + run_*_d calls only upload activations (slot 0/1) and LUTs (9, 10). + """ + emb = config["emb_dim"] + kv = config["kv_dim"] + hid = config["hidden_dim"] + + for layer_idx, w in enumerate(weights_per_layer): + # rms_gemv_rope: 13 args + cache.load_and_run( + "rms_gemv_rope", + RGR_BACKEND, + np.zeros(emb, dtype=bfloat16), # 0 x_in (placeholder) + w["norm_w"], # 1 (static) + np.zeros(emb, dtype=bfloat16), # 2 normed + w["wq"], # 3 (static) + np.zeros(emb, dtype=bfloat16), # 4 q + w["wk"], # 5 (static) + np.zeros(kv, dtype=bfloat16), # 6 k + w["wv"], # 7 (static) + np.zeros(kv, dtype=bfloat16), # 8 v + lut_q, # 9 (NOT static) + lut_k, # 10 (NOT static) + np.zeros(emb, dtype=bfloat16), # 11 q_roped + np.zeros(kv, dtype=bfloat16), # 12 k_roped + output_indices=[8, 11, 12], + static_input_indices=_RGR_STATIC, + intermediate_indices=_RGR_INTERMEDIATE, + bo_key=f"D_rms_gemv_rope_L{layer_idx}", + ) + + # o_gemv_ffn: 15 args + cache.load_and_run( + "o_gemv_ffn", + OGF_BACKEND, + w["wo"], # 0 (static) + np.zeros(emb, dtype=bfloat16), # 1 attn_out (placeholder) + np.zeros(emb, dtype=bfloat16), # 2 proj + np.zeros(emb, dtype=bfloat16), # 3 x_residual (placeholder) + np.zeros(emb, dtype=bfloat16), # 4 res1 + w["ffn_norm_w"], # 5 (static) + np.zeros(emb, dtype=bfloat16), # 6 normed2 + w["w_gate"], # 7 (static) + np.zeros(hid, dtype=bfloat16), # 8 gate + w["w_up"], # 9 (static) + np.zeros(hid, dtype=bfloat16), # 10 up + np.zeros(hid, dtype=bfloat16), # 11 swiglu + w["w_down"], # 12 (static) + np.zeros(emb, dtype=bfloat16), # 13 down + np.zeros(emb, dtype=bfloat16), # 14 output + output_indices=[14], + static_input_indices=_OGF_STATIC, + intermediate_indices=_OGF_INTERMEDIATE, + bo_key=f"D_o_gemv_ffn_L{layer_idx}", + ) + + +def run_rms_gemv_rope_d(cache, layer_inputs, layer_idx=0): + """Production merged dispatch — 6 stitched launches in 1 xrt.run. + + layer_inputs keys: x_in, norm_w, wq, wk, wv, lut_q, lut_k. + Returns dict with normed, q, k, v, q_roped, k_roped, _wall_s. + """ + emb = layer_inputs["x_in"].shape[-1] + # Determine kv_dim from wk shape (W is at slot 0 of GEMV, shape [kv, emb]) + kv = layer_inputs["wk"].shape[0] + + args = [ + layer_inputs["x_in"].astype(bfloat16).flatten(), # 0 + layer_inputs["norm_w"].astype(bfloat16), # 1 (static) + np.zeros(emb, dtype=bfloat16), # 2 normed + layer_inputs["wq"], # 3 (static) + np.zeros(emb, dtype=bfloat16), # 4 q + layer_inputs["wk"], # 5 (static) + np.zeros(kv, dtype=bfloat16), # 6 k + layer_inputs["wv"], # 7 (static) + np.zeros(kv, dtype=bfloat16), # 8 v + layer_inputs["lut_q"].astype(bfloat16), # 9 + layer_inputs["lut_k"].astype(bfloat16), # 10 + np.zeros(emb, dtype=bfloat16), # 11 q_roped + np.zeros(kv, dtype=bfloat16), # 12 k_roped + ] + t0 = time.perf_counter() + out = cache.load_and_run( + "rms_gemv_rope", + RGR_BACKEND, + *args, + output_indices=[2, 4, 6, 8, 11, 12], + static_input_indices=_RGR_STATIC, + intermediate_indices=_RGR_INTERMEDIATE, + bo_key=f"D_rms_gemv_rope_L{layer_idx}", + ) + elapsed = time.perf_counter() - t0 + return { + "normed": out[2], + "q": out[4], + "k": out[6], + "v": out[8], + "q_roped": out[11], + "k_roped": out[12], + "_wall_s": elapsed, + } + + +def run_o_gemv_ffn_d(cache, layer_inputs, layer_idx=0): + """Production merged dispatch — 8 stitched launches in 1 xrt.run. + + layer_inputs keys: wo, attn_out, x_residual, ffn_norm_w, w_gate, w_up, w_down. + Returns dict with output, _wall_s. + """ + emb = layer_inputs["attn_out"].shape[-1] + hid = layer_inputs["w_gate"].shape[0] + + args = [ + layer_inputs["wo"], # 0 (static) + layer_inputs["attn_out"].astype(bfloat16).flatten(), # 1 + np.zeros(emb, dtype=bfloat16), # 2 proj + layer_inputs["x_residual"].astype(bfloat16).flatten(), # 3 + np.zeros(emb, dtype=bfloat16), # 4 res1 + layer_inputs["ffn_norm_w"].astype(bfloat16), # 5 (static) + np.zeros(emb, dtype=bfloat16), # 6 normed2 + layer_inputs["w_gate"], # 7 (static) + np.zeros(hid, dtype=bfloat16), # 8 gate + layer_inputs["w_up"], # 9 (static) + np.zeros(hid, dtype=bfloat16), # 10 up + np.zeros(hid, dtype=bfloat16), # 11 swiglu + layer_inputs["w_down"], # 12 (static) + np.zeros(emb, dtype=bfloat16), # 13 down + np.zeros(emb, dtype=bfloat16), # 14 output + ] + t0 = time.perf_counter() + out = cache.load_and_run( + "o_gemv_ffn", + OGF_BACKEND, + *args, + output_indices=[14], + static_input_indices=_OGF_STATIC, + intermediate_indices=_OGF_INTERMEDIATE, + bo_key=f"D_o_gemv_ffn_L{layer_idx}", + ) + elapsed = time.perf_counter() - t0 + return {"output": out[14], "_wall_s": elapsed} diff --git a/programming_examples/llama32_1b/ablation/decode/cells/common.py b/programming_examples/llama32_1b/ablation/decode/cells/common.py new file mode 100644 index 000000000..6d276fb30 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/decode/cells/common.py @@ -0,0 +1,15 @@ +"""Re-export Plan 1's common helpers.""" + +from prefill.cells.common import ( + compile_standalone_kernels, + _share_bo, + _extract_public_func_name, + standalone_backend_kwargs, +) + +__all__ = [ + "compile_standalone_kernels", + "_share_bo", + "_extract_public_func_name", + "standalone_backend_kwargs", +] diff --git a/programming_examples/llama32_1b/ablation/decode/cells/decode_attn_const.py b/programming_examples/llama32_1b/ablation/decode/cells/decode_attn_const.py new file mode 100644 index 000000000..57efa3a3b --- /dev/null +++ b/programming_examples/llama32_1b/ablation/decode/cells/decode_attn_const.py @@ -0,0 +1,51 @@ +"""Decode CPU attention invariant runner. + +`decode_attention_cpu` runs on the CPU and is structurally identical across +all 4 cells (it's not subject to NPU dispatch optimizations). This module +wraps the production function from llama32_1b_decode.py so every cell calls +exactly the same Python. + +Returns (attn_out, elapsed_seconds). The elapsed_seconds is reported separately +in the per-token results table as the "CPU attention floor" — analogous to how +Plan 1 reports FA's invariant per-layer cost. + +Note: production `decode_attention_cpu` reads `k_cache[:, :current_pos+1, :]` +internally, so the caller MUST have written the new k/v at slot `current_pos` +before calling this function. The KV-cache write happens in the per-token loop +(cells/per_token_loop.py) right after rms_gemv_rope returns, before this call. +""" + +import time + +from llama32_1b_decode import decode_attention_cpu + + +def run_decode_attention( + q_roped, k_cache_layer, v_cache_layer, current_pos, n_heads, n_kv_heads, head_dim +): + """Invoke the production decode_attention_cpu and time it. + + Args: + q_roped: (emb_dim,) bf16 — current token's RoPE'd query + k_cache_layer: (n_kv_heads, max_seq, head_dim) bf16 — this layer's K cache + (must already have new k written at slot current_pos) + v_cache_layer: same shape — this layer's V cache (with new v at current_pos) + current_pos: int — the current token's slot index + n_heads, n_kv_heads, head_dim: ints — model config + + Returns: + attn_out: (emb_dim,) bf16 + elapsed: float — wall time of the CPU attention call (seconds) + """ + t0 = time.perf_counter() + attn_out = decode_attention_cpu( + q_roped, + k_cache_layer, + v_cache_layer, + current_pos, + n_heads, + n_kv_heads, + head_dim, + ) + elapsed = time.perf_counter() - t0 + return attn_out, elapsed diff --git a/programming_examples/llama32_1b/ablation/decode/cells/kv_cache.py b/programming_examples/llama32_1b/ablation/decode/cells/kv_cache.py new file mode 100644 index 000000000..f362b4b33 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/decode/cells/kv_cache.py @@ -0,0 +1,57 @@ +"""KV cache state management for the per-token timed loop. + +Two functions: +- build_initial_kv_cache(config, prompt_len, seed): + Deterministic synthetic pre-fill of `prompt_len` positions for ALL layers. + Returns dict {k_cache, v_cache, current_pos}. The cache shape is + (n_layers, n_kv_heads, max_seq, head_dim) bf16. + +- reset_position(cache, pos): + Zero out the K/V cache slots at position `pos` for ALL layers. + Used between trials to ensure each trial starts from the SAME state + (the pre-filled prompt without the previously-generated token's k/v). +""" + +import numpy as np +from ml_dtypes import bfloat16 + + +def build_initial_kv_cache(config, prompt_len, seed): + """Deterministic synthetic pre-fill of `prompt_len` cache positions. + + config keys required: n_layers, n_kv_heads, head_dim, max_seq + + Returns dict with: + k_cache: (n_layers, n_kv_heads, max_seq, head_dim) bf16 + v_cache: same shape + current_pos: int = prompt_len (next slot to write) + """ + rng = np.random.default_rng(seed) + shape = ( + config["n_layers"], + config["n_kv_heads"], + config["max_seq"], + config["head_dim"], + ) + k = np.zeros(shape, dtype=bfloat16) + v = np.zeros(shape, dtype=bfloat16) + pre_shape = ( + config["n_layers"], + config["n_kv_heads"], + prompt_len, + config["head_dim"], + ) + k[:, :, :prompt_len, :] = (rng.standard_normal(pre_shape) * 0.5).astype(bfloat16) + v[:, :, :prompt_len, :] = (rng.standard_normal(pre_shape) * 0.5).astype(bfloat16) + return {"k_cache": k, "v_cache": v, "current_pos": prompt_len} + + +def reset_position(cache, pos): + """Zero out the K/V cache slots at `pos` for ALL layers. + + Called between timing trials so each trial sees the same initial state + (the pre-filled prompt's positions [0:prompt_len] but no new-token entry + at `pos = prompt_len`). + """ + cache["k_cache"][:, :, pos, :] = 0 + cache["v_cache"][:, :, pos, :] = 0 diff --git a/programming_examples/llama32_1b/ablation/decode/cells/lm_head_const.py b/programming_examples/llama32_1b/ablation/decode/cells/lm_head_const.py new file mode 100644 index 000000000..c0b8cf25d --- /dev/null +++ b/programming_examples/llama32_1b/ablation/decode/cells/lm_head_const.py @@ -0,0 +1,99 @@ +"""LM head invariant runner — production-merged 8-partition GEMV in every cell. + +The LM head (`lm_head_gemv.elf`) is structurally one merged ELF in production +and is held INVARIANT across the 4 cells of Plan 2 (rationale: see spec §4 — +mirrors Plan 1's treatment of FlashAttention). Reporting it as a separate +"fixed cost per token" line keeps the cells comparable on the parts that +DO change. + +Three functions: +- compile_lm_head(cache, config): compiles the production lm_head_gemv ELF. +- preload_lm_head(cache, lm_weight_parts): one-time pre-upload of the 8 + partition weights into BOs (skipped on subsequent calls via static_input_indices). +- run_lm_head(cache, x_normed, vocab_size): invoke + concatenate partition + outputs + argmax → returns (next_token_id, elapsed_seconds). +""" + +import time + +import numpy as np +from ml_dtypes import bfloat16 + +from kernel_builder.backend_presets import LM_GEMV_BACKEND + +_LM_N_PART = 16384 +_LM_N_PARTITIONS = 8 + + +def compile_lm_head(cache, config): + """Compile the production lm_head_gemv ELF (one-time).""" + if "lm_head_gemv" in cache.artifacts: + return + from multi_launch_builder.lm_head_gemv_multi import build_lm_head_gemv_module + + mod = build_lm_head_gemv_module( + emb_dim=config["emb_dim"], + n_partitions=_LM_N_PARTITIONS, + n_part=_LM_N_PART, + ) + cache.compile_and_cache( + "lm_head_gemv", + mod, + {**LM_GEMV_BACKEND, "verbose": getattr(cache, "verbose", False)}, + ) + cache._save_manifest() + + +def preload_lm_head(cache, lm_weight_parts, config): + """One-time pre-upload of LM head partition weights. + + `lm_weight_parts`: list of 8 numpy arrays, each shape (_LM_N_PART, emb_dim). + The first call materializes BOs and writes weights; subsequent run_lm_head + calls skip weight upload via static_input_indices. + """ + emb_dim = config["emb_dim"] + inputs = [np.zeros(emb_dim, dtype=bfloat16)] + for p in range(_LM_N_PARTITIONS): + inputs.append(lm_weight_parts[p]) + inputs.append(np.zeros(_LM_N_PART, dtype=bfloat16)) + cache.load_and_run( + "lm_head_gemv", + LM_GEMV_BACKEND, + *inputs, + output_indices=[2 + 2 * p for p in range(_LM_N_PARTITIONS)], + static_input_indices={1 + 2 * p for p in range(_LM_N_PARTITIONS)}, + intermediate_indices={2 + 2 * p for p in range(_LM_N_PARTITIONS)}, + ) + + +def run_lm_head(cache, x_normed, vocab_size, config): + """Run LM head; return (next_token_id, elapsed_seconds). + + `x_normed`: (emb_dim,) bf16 — the final RMSNorm output for the current token. + `vocab_size`: int — usually 128256 for Llama-3.2-1B. + + Mirrors the production code in llama32_1b_inference.py:434-446. + """ + emb_dim = config["emb_dim"] + inputs = [x_normed.astype(bfloat16).flatten()] + for p in range(_LM_N_PARTITIONS): + # Placeholder weight — actual weight in BOs from preload + static_input_indices. + inputs.append(np.zeros((_LM_N_PART, emb_dim), dtype=bfloat16)) + inputs.append(np.zeros(_LM_N_PART, dtype=bfloat16)) + + t0 = time.perf_counter() + results = cache.load_and_run( + "lm_head_gemv", + LM_GEMV_BACKEND, + *inputs, + output_indices=[2 + 2 * p for p in range(_LM_N_PARTITIONS)], + static_input_indices={1 + 2 * p for p in range(_LM_N_PARTITIONS)}, + intermediate_indices={2 + 2 * p for p in range(_LM_N_PARTITIONS)}, + ) + elapsed = time.perf_counter() - t0 + + logits = np.concatenate( + [results[2 + 2 * p] for p in range(_LM_N_PARTITIONS)], axis=0 + )[:vocab_size] + next_token = int(np.argmax(logits)) + return next_token, elapsed diff --git a/programming_examples/llama32_1b/ablation/decode/cells/per_token_loop.py b/programming_examples/llama32_1b/ablation/decode/cells/per_token_loop.py new file mode 100644 index 000000000..8abaf5c44 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/decode/cells/per_token_loop.py @@ -0,0 +1,150 @@ +"""Per-token decode loop wrapper — the end-to-end timed unit for Plan 2. + +Generates ONE decode token at a fixed `current_pos` from a pre-filled KV cache. +The cell-specific dispatch is injected via `run_rms_gemv_rope` and +`run_o_gemv_ffn` function arguments so the same wrapper works for all 4 cells. + +For each of the 16 layers: + 1. NPU rms_gemv_rope (cell-specific) → q_roped, k_roped, v + 2. Write k_roped, v into KV cache at current_pos + 3. CPU decode_attention_cpu (invariant) → attn_out + 4. NPU o_gemv_ffn (cell-specific) → next-layer activation + +After 16 layers: + 5. CPU final RMSNorm on the running hidden state (single row) + 6. NPU lm_head_gemv (invariant) → logits → argmax → next_token + +The `layer_inputs_per_layer` list contains per-layer weight bundles +(rms_gemv_rope's: norm_w, wq, wk, wv, lut_q, lut_k; o_gemv_ffn's: wo, +ffn_norm_w, w_gate, w_up, w_down). The cell-specific runners are +responsible for assembling these into the kernel-group's expected +argument order. + +Returns a dict with per-stage wall times for downstream attribution. +""" + +import time + +import numpy as np +from ml_dtypes import bfloat16 + +from cells.decode_attn_const import run_decode_attention +from cells.lm_head_const import run_lm_head + + +def _final_rms_norm_cpu(x_bf16, weight_bf16, eps=1e-5): + """Single-row RMSNorm on the final hidden state (mirrors production). + + x: (emb_dim,) bf16; weight: (emb_dim,) bf16. Returns (emb_dim,) bf16. + """ + x_f32 = x_bf16.astype(np.float32) + w_f32 = weight_bf16.astype(np.float32) + rms = np.sqrt(np.mean(x_f32 * x_f32) + eps) + return ((x_f32 / rms) * w_f32).astype(bfloat16) + + +def run_one_decode_token( + cache, + config, + kv_cache, + layer_inputs_per_layer, + final_norm_w, + lm_weight_parts, + initial_x_decode, + current_pos, + run_rms_gemv_rope, + run_o_gemv_ffn, +): + """Generate ONE decode token end-to-end. THIS IS THE TIMED UNIT. + + Args: + cache: shared KernelCache with all ELFs compiled + preloaded + config: dict with emb_dim, n_heads, n_kv_heads, head_dim, n_layers, vocab_size + kv_cache: dict from build_initial_kv_cache (mutated in-place) + layer_inputs_per_layer: list of N dicts, one per layer, with weight tensors + final_norm_w: (emb_dim,) bf16 — final RMSNorm weight + lm_weight_parts: list of 8 (16384, emb_dim) arrays — LM head partitions + initial_x_decode: (emb_dim,) bf16 — the token's embedding + current_pos: int — the slot in KV cache to write the new k/v + run_rms_gemv_rope: callable(cache, layer_inputs, layer_idx) -> dict with + q_roped, k_roped, v, _wall_s + run_o_gemv_ffn: callable(cache, layer_inputs, layer_idx) -> dict with + output, _wall_s + + Returns dict with: + next_token: int + per_layer_npu_wall_s: list of N floats (rms_gemv_rope + o_gemv_ffn per layer) + per_layer_rms_gemv_rope_wall_s: list of N floats + per_layer_o_gemv_ffn_wall_s: list of N floats + cpu_attn_wall_s: float (sum across N layers) + lm_head_wall_s: float + total_wall_s: float (everything inside the timer) + """ + n_layers = config["n_layers"] + n_heads = config["n_heads"] + n_kv_heads = config["n_kv_heads"] + head_dim = config["head_dim"] + vocab_size = config["vocab_size"] + + per_layer_rg = [] + per_layer_of = [] + cpu_attn_total = 0.0 + x = initial_x_decode + + t_total_start = time.perf_counter() + for L in range(n_layers): + layer_in = dict(layer_inputs_per_layer[L]) + layer_in["x_in"] = x + layer_in["current_pos"] = current_pos + + # 1. rms_gemv_rope (NPU, cell-specific) + rg_out = run_rms_gemv_rope(cache, layer_in, layer_idx=L) + per_layer_rg.append(rg_out["_wall_s"]) + + q_roped = rg_out["q_roped"].astype(bfloat16) + k_roped = rg_out["k_roped"].astype(bfloat16) + v = rg_out["v"].astype(bfloat16) + + # 2. KV cache write (CPU) + kv_cache["k_cache"][L, :, current_pos, :] = k_roped.reshape( + n_kv_heads, head_dim + ) + kv_cache["v_cache"][L, :, current_pos, :] = v.reshape(n_kv_heads, head_dim) + + # 3. CPU decode attention (invariant) + attn_out, attn_t = run_decode_attention( + q_roped.flatten(), + kv_cache["k_cache"][L], + kv_cache["v_cache"][L], + current_pos, + n_heads, + n_kv_heads, + head_dim, + ) + cpu_attn_total += attn_t + + # 4. o_gemv_ffn (NPU, cell-specific) + of_in = dict(layer_in) + of_in["attn_out"] = attn_out.astype(bfloat16) + of_in["x_residual"] = x # the activation entering THIS layer + of_out = run_o_gemv_ffn(cache, of_in, layer_idx=L) + per_layer_of.append(of_out["_wall_s"]) + + x = of_out["output"].astype(bfloat16).flatten() + + # 5. Final RMSNorm (CPU, single row) + x_normed = _final_rms_norm_cpu(x, final_norm_w) + + # 6. LM head (NPU, invariant) + argmax + next_token, lm_t = run_lm_head(cache, x_normed, vocab_size, config) + + total_wall = time.perf_counter() - t_total_start + return { + "next_token": next_token, + "per_layer_npu_wall_s": [a + b for a, b in zip(per_layer_rg, per_layer_of)], + "per_layer_rms_gemv_rope_wall_s": per_layer_rg, + "per_layer_o_gemv_ffn_wall_s": per_layer_of, + "cpu_attn_wall_s": cpu_attn_total, + "lm_head_wall_s": lm_t, + "total_wall_s": total_wall, + } diff --git a/programming_examples/llama32_1b/ablation/decode/golden/__init__.py b/programming_examples/llama32_1b/ablation/decode/golden/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/programming_examples/llama32_1b/ablation/decode/golden/golden_meta.json b/programming_examples/llama32_1b/ablation/decode/golden/golden_meta.json new file mode 100644 index 000000000..f9a4a1184 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/decode/golden/golden_meta.json @@ -0,0 +1,28 @@ +{ + "config": { + "emb_dim": 2048, + "kv_dim": 512, + "n_heads": 32, + "n_kv_heads": 8, + "head_dim": 64, + "hidden_dim": 8192, + "n_layers": 16, + "max_seq": 2048, + "vocab_size": 128256 + }, + "prompt_len": 7, + "current_pos": 7, + "seed": 42, + "layer_idx": 0, + "rms_gemv_rope_outputs": { + "normed": "a97e976415483974", + "q": "8eb0329b8a682062", + "k": "858e3700aa681e8f", + "v": "3614ed9453d88a31", + "q_roped": "206a8aedfaf6fc25", + "k_roped": "a30ed65232069ab6" + }, + "o_gemv_ffn_outputs": { + "output": "0f3cd9c0cfc685bb" + } +} \ No newline at end of file diff --git a/programming_examples/llama32_1b/ablation/decode/golden/golden_o_gemv_ffn_decode.npz b/programming_examples/llama32_1b/ablation/decode/golden/golden_o_gemv_ffn_decode.npz new file mode 100644 index 000000000..37d5357d7 Binary files /dev/null and b/programming_examples/llama32_1b/ablation/decode/golden/golden_o_gemv_ffn_decode.npz differ diff --git a/programming_examples/llama32_1b/ablation/decode/golden/golden_rms_gemv_rope_decode.npz b/programming_examples/llama32_1b/ablation/decode/golden/golden_rms_gemv_rope_decode.npz new file mode 100644 index 000000000..278b4a177 Binary files /dev/null and b/programming_examples/llama32_1b/ablation/decode/golden/golden_rms_gemv_rope_decode.npz differ diff --git a/programming_examples/llama32_1b/ablation/decode/golden/regen_golden.py b/programming_examples/llama32_1b/ablation/decode/golden/regen_golden.py new file mode 100644 index 000000000..1c6cf3251 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/decode/golden/regen_golden.py @@ -0,0 +1,200 @@ +"""Regenerate decode golden fixtures by running Cell D for layer 0 at current_pos=7. + +Uses deterministic synthetic inputs (numpy seed=42). +Outputs: + golden/golden_rms_gemv_rope_decode.npz + golden/golden_o_gemv_ffn_decode.npz + golden/golden_meta.json + +Usage: + flock -x -w 1800 /tmp/mlir-air-npu.lock python3 golden/regen_golden.py +""" + +import hashlib +import json +import os +import sys + +import numpy as np +from ml_dtypes import bfloat16 + +# sys.path setup — make decode/, ablation/, llama32_1b/, programming_examples/ importable +_THIS = os.path.dirname(os.path.abspath(__file__)) +_DECODE = os.path.dirname(_THIS) +_ABLATION = os.path.dirname(_DECODE) +_LLAMA = os.path.dirname(_ABLATION) +_PE = os.path.dirname(_LLAMA) +for p in (_PE, _LLAMA, _ABLATION, os.path.join(_ABLATION, "prefill"), _DECODE): + if p not in sys.path: + sys.path.insert(0, p) + +from kernel_builder.cache import KernelCache +from cells.cell_d_merged import ( + compile_cell_d, + preload_cell_d, + run_rms_gemv_rope_d, + run_o_gemv_ffn_d, +) + +CONFIG = { + "seq_len": 1, # decode is single-token; seq_len present for shape-helper compatibility + "emb_dim": 2048, + "kv_dim": 512, + "n_heads": 32, + "n_kv_heads": 8, + "head_dim": 64, + "hidden_dim": 8192, + "n_layers": 16, + "max_seq": 2048, + "vocab_size": 128256, +} + +PROMPT_LEN = 7 +CURRENT_POS = 7 # decode generates the token at position 7 (after a 7-token prefill) +SEED = 42 + + +def synthetic_layer_weights(layer_idx, config, seed): + """Per-layer weights — already in production-decode transposed shape. + + GEMV convention: W at slot 0 with shape (out_dim, in_dim). HuggingFace + storage uses (out, in) too, but production pre-transposes; for synthetic + inputs we just generate at the production shape directly. + """ + rng = np.random.default_rng(seed + layer_idx) + emb = config["emb_dim"] + kv = config["kv_dim"] + hid = config["hidden_dim"] + return { + "norm_w": rng.standard_normal(emb).astype(bfloat16), + "wq": (rng.standard_normal((emb, emb)) * 0.02).astype(bfloat16), + "wk": (rng.standard_normal((kv, emb)) * 0.02).astype(bfloat16), + "wv": (rng.standard_normal((kv, emb)) * 0.02).astype(bfloat16), + "wo": (rng.standard_normal((emb, emb)) * 0.02).astype(bfloat16), + "ffn_norm_w": rng.standard_normal(emb).astype(bfloat16), + "w_gate": (rng.standard_normal((hid, emb)) * 0.02).astype(bfloat16), + "w_up": (rng.standard_normal((hid, emb)) * 0.02).astype(bfloat16), + "w_down": (rng.standard_normal((emb, hid)) * 0.02).astype(bfloat16), + } + + +def synthetic_x_in(config, seed): + """The token's embedding entering layer 0.""" + rng = np.random.default_rng(seed + 9999) + return rng.standard_normal(config["emb_dim"]).astype(bfloat16) + + +def synthetic_lut(config, seed): + """Synthetic RoPE LUT slice at the timed current_pos (constant across trials).""" + rng = np.random.default_rng(seed + 8888) + emb = config["emb_dim"] + kv = config["kv_dim"] + return { + "lut_q": rng.standard_normal(emb).astype(bfloat16), + "lut_k": rng.standard_normal(kv).astype(bfloat16), + } + + +def synthetic_attn_out(config, seed): + """Synthetic post-attention activation entering o_gemv_ffn. + + For golden generation we don't actually run CPU attention — we just need + a deterministic byte-stable input for the o_gemv_ffn golden. The validation + gate compares Cell D against this golden in isolation; what feeds o_gemv_ffn + in actual inference is decode_attention_cpu(q_roped, k/v cache, ...) but that + data flow is exercised by the per-token loop test, not by this golden. + """ + rng = np.random.default_rng(seed + 7777) + return rng.standard_normal(config["emb_dim"]).astype(bfloat16) + + +def main(): + print("=" * 60) + print("Plan 2 (full decode) golden regeneration") + print(f" current_pos={CURRENT_POS}, prompt_len={PROMPT_LEN}, seed={SEED}") + print("=" * 60) + + cache_dir = os.path.join(_DECODE, "build") + os.makedirs(cache_dir, exist_ok=True) + cache = KernelCache(cache_dir=cache_dir, verbose=True) + cache.load_manifest() + + # 1. Compile both ELFs + print("\n[1/5] Compiling Cell D ELFs (rms_gemv_rope + o_gemv_ffn)...") + compile_cell_d(cache, CONFIG) + + # 2. Generate synthetic per-layer weights (just layer 0 for goldens) + print("\n[2/5] Generating synthetic weights for layer 0 (seed=42)...") + weights_layer0 = synthetic_layer_weights(layer_idx=0, config=CONFIG, seed=SEED) + lut = synthetic_lut(CONFIG, SEED) + x_in = synthetic_x_in(CONFIG, SEED) + attn_out_synth = synthetic_attn_out(CONFIG, SEED) + + # 3. Pre-load layer 0 weights into Cell D's BOs + print("\n[3/5] Pre-loading layer 0 weights into Cell D BOs...") + preload_cell_d(cache, [weights_layer0], lut["lut_q"], lut["lut_k"], CONFIG) + + # 4. Run rms_gemv_rope Cell D, capture outputs as golden + print("\n[4/5] Running rms_gemv_rope (Cell D) → golden_rms_gemv_rope_decode.npz") + rg_inputs = { + "x_in": x_in, + "norm_w": weights_layer0["norm_w"], + "wq": weights_layer0["wq"], + "wk": weights_layer0["wk"], + "wv": weights_layer0["wv"], + "lut_q": lut["lut_q"], + "lut_k": lut["lut_k"], + } + rg_out = run_rms_gemv_rope_d(cache, rg_inputs, layer_idx=0) + rg_path = os.path.join(_THIS, "golden_rms_gemv_rope_decode.npz") + np.savez( + rg_path, + normed=rg_out["normed"], + q=rg_out["q"], + k=rg_out["k"], + v=rg_out["v"], + q_roped=rg_out["q_roped"], + k_roped=rg_out["k_roped"], + ) + print(f" → wrote {rg_path} ({os.path.getsize(rg_path)} bytes)") + + # 5. Run o_gemv_ffn Cell D with synthetic attn_out, capture output as golden + print("\n[5/5] Running o_gemv_ffn (Cell D) → golden_o_gemv_ffn_decode.npz") + of_inputs = { + "wo": weights_layer0["wo"], + "attn_out": attn_out_synth, + "x_residual": x_in, + "ffn_norm_w": weights_layer0["ffn_norm_w"], + "w_gate": weights_layer0["w_gate"], + "w_up": weights_layer0["w_up"], + "w_down": weights_layer0["w_down"], + } + of_out = run_o_gemv_ffn_d(cache, of_inputs, layer_idx=0) + of_path = os.path.join(_THIS, "golden_o_gemv_ffn_decode.npz") + np.savez(of_path, output=of_out["output"]) + print(f" → wrote {of_path} ({os.path.getsize(of_path)} bytes)") + + # Meta JSON + def _h(arr): + return hashlib.sha256(arr.tobytes()).hexdigest()[:16] + + meta = { + "config": CONFIG, + "prompt_len": PROMPT_LEN, + "current_pos": CURRENT_POS, + "seed": SEED, + "layer_idx": 0, + "rms_gemv_rope_outputs": { + k: _h(v) for k, v in rg_out.items() if not k.startswith("_") + }, + "o_gemv_ffn_outputs": {"output": _h(of_out["output"])}, + } + meta_path = os.path.join(_THIS, "golden_meta.json") + with open(meta_path, "w") as f: + json.dump(meta, f, indent=2) + print(f" → wrote {meta_path}") + print("\nDone.") + + +if __name__ == "__main__": + main() diff --git a/programming_examples/llama32_1b/ablation/decode/run_ablation.py b/programming_examples/llama32_1b/ablation/decode/run_ablation.py new file mode 100644 index 000000000..f2b0a45b3 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/decode/run_ablation.py @@ -0,0 +1,415 @@ +"""Run the Plan 2 (full decode) 4-cell ablation. + +Per cell: + - Compile (idempotent, skipped if cached) + - Preload weights into per-layer BOs (Cells B/C/D; Cell A skips this) + - Validate: run rms_gemv_rope and o_gemv_ffn ONCE for layer 0 with synthetic + inputs, compare bytes to committed goldens + - 5 timed trials of per_token_loop generating ONE decode token at fixed + current_pos, drop trial 1 as warmup + - Median + (min, max) of trials 2-5 + +Per-kernel-group medians for rms_gemv_rope and o_gemv_ffn are extracted +from per_token_loop's per-layer wall arrays (medians across the 16 layers +within trial 2-5). + +Usage: + flock -x -w 1800 /tmp/mlir-air-npu.lock python3 run_ablation.py --trials 5 +""" + +import argparse +import json +import os +import sys +import time + +# sys.path setup (mirrors conftest.py) +_THIS = os.path.dirname(os.path.abspath(__file__)) +_ABLATION = os.path.dirname(_THIS) +_LLAMA = os.path.dirname(_ABLATION) +_PE = os.path.dirname(_LLAMA) +for p in (_PE, _LLAMA, _ABLATION, os.path.join(_ABLATION, "prefill")): + if p not in sys.path: + sys.path.append(p) +# Decode dir at sys.path[0] so decode/cells/ wins over prefill/cells/ +if _THIS in sys.path: + sys.path.remove(_THIS) +sys.path.insert(0, _THIS) +# Drop any stale `cells`/`specs`/`standalone_builders` modules from prior imports +for _stale in [ + m + for m in list(sys.modules) + if m.startswith(("cells", "specs", "standalone_builders")) +]: + del sys.modules[_stale] + +import numpy as np +from ml_dtypes import bfloat16 + +from kernel_builder.cache import KernelCache +from kernel_builder.backend_presets import RGR_BACKEND, OGF_BACKEND + +from validate import GoldenMismatch, validate_against_golden +from cells import cell_a_naive, cell_b_static, cell_c_charitable, cell_d_merged +from cells.kv_cache import build_initial_kv_cache, reset_position +from cells.lm_head_const import ( + compile_lm_head, + preload_lm_head, + _LM_N_PART, + _LM_N_PARTITIONS, +) +from cells.per_token_loop import run_one_decode_token +from specs.rms_gemv_rope import SPEC as RGR_SPEC +from specs.o_gemv_ffn import SPEC as OGF_SPEC +from golden.regen_golden import ( + CONFIG, + PROMPT_LEN, + CURRENT_POS, + SEED, + synthetic_layer_weights, + synthetic_lut, + synthetic_x_in, + synthetic_attn_out, +) + +GOLDEN_DIR = os.path.join(_THIS, "golden") + + +# --------------------------------------------------------------------------- +# Cell-specific dispatch adapters for the per-token loop +# --------------------------------------------------------------------------- + + +def _wrap_rg_runner(cell, spec): + """Return a (cache, layer_inputs, layer_idx) -> dict adapter. + + Output dict normalizes sub-launch names to {normed, q, k, v, q_roped, + k_roped} for downstream consumers (per_token_loop, validation). + """ + if cell == "D": + + def _run(cache, layer_inputs, layer_idx=0): + return cell_d_merged.run_rms_gemv_rope_d(cache, layer_inputs, layer_idx) + + return _run + + if cell == "A": + runner = cell_a_naive.run_cell_a + elif cell == "B": + runner = cell_b_static.run_cell_b + elif cell == "C": + runner = cell_c_charitable.run_cell_c + else: + raise ValueError(f"unknown cell {cell!r}") + + def _run(cache, layer_inputs, layer_idx=0): + out = runner( + cache, spec, layer_inputs, CONFIG, RGR_BACKEND, layer_idx=layer_idx + ) + # Normalize keys for downstream consumers + return { + "normed": out["rmsnorm"], + "q": out["q_gemv"], + "k": out["k_gemv"], + "v": out["v_gemv"], + "q_roped": out["rope_q"], + "k_roped": out["rope_k"], + "_wall_s": out["_wall_s"], + } + + return _run + + +def _wrap_of_runner(cell, spec): + if cell == "D": + + def _run(cache, layer_inputs, layer_idx=0): + return cell_d_merged.run_o_gemv_ffn_d(cache, layer_inputs, layer_idx) + + return _run + + if cell == "A": + runner = cell_a_naive.run_cell_a + elif cell == "B": + runner = cell_b_static.run_cell_b + elif cell == "C": + runner = cell_c_charitable.run_cell_c + else: + raise ValueError(f"unknown cell {cell!r}") + + def _run(cache, layer_inputs, layer_idx=0): + out = runner( + cache, spec, layer_inputs, CONFIG, OGF_BACKEND, layer_idx=layer_idx + ) + # Cells A/B/C return all 8 sub-launch outputs; the per_token_loop + # only needs the final residual add as 'output'. + return {"output": out["add_ffn_residual"], "_wall_s": out["_wall_s"]} + + return _run + + +# --------------------------------------------------------------------------- +# Validation: run layer 0 once, compare to goldens +# --------------------------------------------------------------------------- + + +def _validate_cell(cell, cache, layer0_weights, lut, x_in, attn_out_synth): + """Run rms_gemv_rope and o_gemv_ffn for layer 0 (synthetic inputs) and + bit-exact compare to committed goldens. Raises GoldenMismatch on diff.""" + rg_runner = _wrap_rg_runner(cell, RGR_SPEC) + of_runner = _wrap_of_runner(cell, OGF_SPEC) + + rg_in = { + "x_in": x_in, + "norm_w": layer0_weights["norm_w"], + "wq": layer0_weights["wq"], + "wk": layer0_weights["wk"], + "wv": layer0_weights["wv"], + "lut_q": lut["lut_q"], + "lut_k": lut["lut_k"], + } + rg_out = rg_runner(cache, rg_in, layer_idx=0) + rg_compare = {k: rg_out[k] for k in ("normed", "q", "k", "v", "q_roped", "k_roped")} + validate_against_golden(rg_compare, GOLDEN_DIR, "golden_rms_gemv_rope_decode.npz") + + of_in = { + "wo": layer0_weights["wo"], + "attn_out": attn_out_synth, + "x_residual": x_in, + "ffn_norm_w": layer0_weights["ffn_norm_w"], + "w_gate": layer0_weights["w_gate"], + "w_up": layer0_weights["w_up"], + "w_down": layer0_weights["w_down"], + } + of_out = of_runner(cache, of_in, layer_idx=0) + of_compare = {"output": of_out["output"]} + validate_against_golden(of_compare, GOLDEN_DIR, "golden_o_gemv_ffn_decode.npz") + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--trials", type=int, default=5) + ap.add_argument("--out", default=None) + args = ap.parse_args() + + cache_dir = os.path.join(_THIS, "build") + os.makedirs(cache_dir, exist_ok=True) + cache = KernelCache(cache_dir=cache_dir, verbose=False) + cache.load_manifest() + + # ------ 1. Compile all cells (idempotent) ------ + print("=== Compiling cells (idempotent) ===") + cell_a_naive.compile_cell_a(cache, RGR_SPEC, RGR_BACKEND) + cell_a_naive.compile_cell_a(cache, OGF_SPEC, OGF_BACKEND) + cell_b_static.compile_cell_b(cache, RGR_SPEC, RGR_BACKEND) + cell_b_static.compile_cell_b(cache, OGF_SPEC, OGF_BACKEND) + cell_c_charitable.compile_cell_c(cache, RGR_SPEC, RGR_BACKEND) + cell_c_charitable.compile_cell_c(cache, OGF_SPEC, OGF_BACKEND) + cell_d_merged.compile_cell_d(cache, CONFIG) + compile_lm_head(cache, CONFIG) + print("All compiled.\n") + + # ------ 2. Generate synthetic inputs ------ + n_layers = CONFIG["n_layers"] + weights_per_layer = [ + synthetic_layer_weights(L, CONFIG, SEED) for L in range(n_layers) + ] + lut = synthetic_lut(CONFIG, SEED) + x_in = synthetic_x_in(CONFIG, SEED) # token embedding entering layer 0 + attn_out_synth = synthetic_attn_out(CONFIG, SEED) # for golden validation only + + # Synthetic LM head partitions + rng = np.random.default_rng(SEED + 6666) + lm_weight_parts = [ + (rng.standard_normal((_LM_N_PART, CONFIG["emb_dim"])) * 0.02).astype(bfloat16) + for _ in range(_LM_N_PARTITIONS) + ] + final_norm_w = rng.standard_normal(CONFIG["emb_dim"]).astype(bfloat16) + + # ------ 3. Per-cell weight prep helpers (called inside per-cell loop) ------ + + rg_weights_per_layer = [ + {k: w[k] for k in ("norm_w", "wq", "wk", "wv")} for w in weights_per_layer + ] + for d in rg_weights_per_layer: + d["lut_q"] = lut["lut_q"] + d["lut_k"] = lut["lut_k"] + + of_weights_per_layer = [ + {k: w[k] for k in ("wo", "ffn_norm_w", "w_gate", "w_up", "w_down")} + for w in weights_per_layer + ] + + def _preload_for_cell(cell): + """Preload BOs for the given cell. Cell A doesn't preload (naive=True).""" + if cell == "B": + cell_b_static.preload_cell_b( + cache, RGR_SPEC, rg_weights_per_layer, CONFIG, RGR_BACKEND + ) + cell_b_static.preload_cell_b( + cache, OGF_SPEC, of_weights_per_layer, CONFIG, OGF_BACKEND + ) + elif cell == "C": + cell_c_charitable.preload_cell_c( + cache, RGR_SPEC, rg_weights_per_layer, CONFIG, RGR_BACKEND + ) + cell_c_charitable.preload_cell_c( + cache, OGF_SPEC, of_weights_per_layer, CONFIG, OGF_BACKEND + ) + elif cell == "D": + cell_d_merged.preload_cell_d( + cache, weights_per_layer, lut["lut_q"], lut["lut_k"], CONFIG + ) + # LM head invariant — preload for every cell (held INVARIANT in ablation) + preload_lm_head(cache, lm_weight_parts, CONFIG) + + def _unload_all_contexts(): + """Free up all NPU HW context slots and drop cached BOs. + + The NPU HW context limit is ~16. Cells A/B/C each load 14 standalone + ELFs + 1 LM head = 15 contexts; switching cells without unloading + would exceed the limit. We unload after each cell finishes its trials + so the next cell starts with a clean slot table. + """ + for name, (backend, _) in list(cache._loaded.items()): + try: + backend.unload() + except Exception: + pass + cache._loaded.clear() + cache._cached_bos.clear() + + # ------ 4. Run each cell: preload + validate + 5 trials + unload ------ + results = { + "config": CONFIG, + "current_pos": CURRENT_POS, + "prompt_len": PROMPT_LEN, + "trials": args.trials, + "cells": {}, + } + + for cell in ["A", "B", "C", "D"]: + print(f"=== Cell {cell}: preload + validate + {args.trials} trials ===") + _preload_for_cell(cell) + # Validate against goldens (single layer 0 run) + try: + _validate_cell( + cell, + cache, + weights_per_layer[0], + lut, + x_in, + attn_out_synth, + ) + validation = "PASS" + print(f" Cell {cell}: VALIDATION PASS") + except GoldenMismatch as e: + validation = f"FAIL: {e}" + print(f" Cell {cell}: VALIDATION FAIL — {e}") + results["cells"][cell] = {"validation": validation} + continue + + # Build per-layer inputs for the per_token_loop + layer_inputs_per_layer = [] + for L in range(n_layers): + li = { + "norm_w": weights_per_layer[L]["norm_w"], + "wq": weights_per_layer[L]["wq"], + "wk": weights_per_layer[L]["wk"], + "wv": weights_per_layer[L]["wv"], + "wo": weights_per_layer[L]["wo"], + "ffn_norm_w": weights_per_layer[L]["ffn_norm_w"], + "w_gate": weights_per_layer[L]["w_gate"], + "w_up": weights_per_layer[L]["w_up"], + "w_down": weights_per_layer[L]["w_down"], + "lut_q": lut["lut_q"], + "lut_k": lut["lut_k"], + } + layer_inputs_per_layer.append(li) + + # Build the cell-specific runners + rg_runner = _wrap_rg_runner(cell, RGR_SPEC) + of_runner = _wrap_of_runner(cell, OGF_SPEC) + + # 5 timed trials + trial_results = [] + for trial in range(args.trials): + # Reset KV cache to a fresh pre-filled state + kv_cache = build_initial_kv_cache(CONFIG, prompt_len=PROMPT_LEN, seed=SEED) + # Reset position CURRENT_POS so subsequent trials don't carry over the + # previously-generated k/v at slot CURRENT_POS + reset_position(kv_cache, CURRENT_POS) + + out = run_one_decode_token( + cache=cache, + config=CONFIG, + kv_cache=kv_cache, + layer_inputs_per_layer=layer_inputs_per_layer, + final_norm_w=final_norm_w, + lm_weight_parts=lm_weight_parts, + initial_x_decode=x_in, + current_pos=CURRENT_POS, + run_rms_gemv_rope=rg_runner, + run_o_gemv_ffn=of_runner, + ) + trial_results.append(out) + print( + f" trial {trial+1}: total={out['total_wall_s']*1000:.2f}ms" + f" cpu_attn={out['cpu_attn_wall_s']*1000:.2f}ms" + f" lm_head={out['lm_head_wall_s']*1000:.2f}ms" + ) + + # Drop trial 1 (warmup), median + (min,max) of remaining + kept = trial_results[1:] + kept_total = sorted([t["total_wall_s"] for t in kept]) + median_total = kept_total[len(kept_total) // 2] + + # Per-kernel-group medians: median over (16 layers × 4 kept trials) of per-layer wall + rg_walls = [w for t in kept for w in t["per_layer_rms_gemv_rope_wall_s"]] + of_walls = [w for t in kept for w in t["per_layer_o_gemv_ffn_wall_s"]] + rg_walls_sorted = sorted(rg_walls) + of_walls_sorted = sorted(of_walls) + rg_median_per_call = rg_walls_sorted[len(rg_walls_sorted) // 2] + of_median_per_call = of_walls_sorted[len(of_walls_sorted) // 2] + + # CPU attention floor (median across kept trials) + cpu_walls = sorted([t["cpu_attn_wall_s"] for t in kept]) + lm_walls = sorted([t["lm_head_wall_s"] for t in kept]) + + cell_summary = { + "validation": validation, + "all_trials_total_s": [t["total_wall_s"] for t in trial_results], + "median_total_s": median_total, + "min_total_s": min([t["total_wall_s"] for t in kept]), + "max_total_s": max([t["total_wall_s"] for t in kept]), + "rms_gemv_rope_per_call_median_s": rg_median_per_call, + "o_gemv_ffn_per_call_median_s": of_median_per_call, + "cpu_attn_total_median_s": cpu_walls[len(cpu_walls) // 2], + "lm_head_median_s": lm_walls[len(lm_walls) // 2], + "next_token": trial_results[-1]["next_token"], + } + results["cells"][cell] = cell_summary + print( + f" Cell {cell} median total: {median_total*1000:.2f}ms " + f"rg/call: {rg_median_per_call*1000:.2f}ms " + f"of/call: {of_median_per_call*1000:.2f}ms" + ) + + # Free up NPU HW context slots before next cell loads its ELFs + _unload_all_contexts() + print(f" (unloaded contexts)\n") + + # ------ 5. Write results JSON ------ + out_path = args.out or os.path.join(_THIS, f"results_{int(time.time())}.json") + with open(out_path, "w") as f: + json.dump(results, f, indent=2) + print(f"Wrote {out_path}") + + +if __name__ == "__main__": + main() diff --git a/programming_examples/llama32_1b/ablation/decode/specs/__init__.py b/programming_examples/llama32_1b/ablation/decode/specs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/programming_examples/llama32_1b/ablation/decode/specs/kernel_group.py b/programming_examples/llama32_1b/ablation/decode/specs/kernel_group.py new file mode 100644 index 000000000..3eb295c97 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/decode/specs/kernel_group.py @@ -0,0 +1,14 @@ +"""Re-export Plan 1's KernelGroupSpec dataclasses (single source of truth). + +Decode specs (rms_gemv_rope, o_gemv_ffn) and cells reference these. Keeping +one definition prevents drift across the three plans. +""" + +from prefill.specs.kernel_group import ( + SubLaunchSpec, + BatonLink, + KernelGroupSpec, + validate_baton_links, +) + +__all__ = ["SubLaunchSpec", "BatonLink", "KernelGroupSpec", "validate_baton_links"] diff --git a/programming_examples/llama32_1b/ablation/decode/specs/o_gemv_ffn.py b/programming_examples/llama32_1b/ablation/decode/specs/o_gemv_ffn.py new file mode 100644 index 000000000..b5f5f5af6 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/decode/specs/o_gemv_ffn.py @@ -0,0 +1,179 @@ +"""Concrete KernelGroupSpec for the decode o_gemv_ffn kernel-group. + +Mirrors the production stitch-spec in +multi_launch_builder/o_gemv_ffn_multi.py:308-482 (the 8-launch decode pipeline: +O GEMV + Add + RMSNorm + Gate GEMV + Up GEMV + SwiGLU + Down GEMV + Add). + +15 merged-func args (slots 0-14); weights at {0,5,7,9,12}; +intermediates at {2,4,6,8,10,11,13,14}. + +Slot conventions for standalones (CRITICAL — different from prefill GEMM): + - gemv: (W[out, in], x[in], y[out]) weight=0, out=2 (matvec convention) + - add_2d: (a[N,d], b[N,d], out[N,d]) no weight, out=2 + (called as N=emb_dim//8, d=emb_dim, herd_x=8) + - rms_1d: (x[emb], norm_w[emb], out[emb]) weight=1, out=2 + - swiglu: (gate[hidden], up[hidden], out[hidden]) no weight, out=2 + +Production decode shapes (single token): + emb_dim=2048, hidden_dim=8192, head_dim=64. + K=2048 GEMVs (O, Gate, Up): tile_m=8, m_input=4, herd_m=8 + K=8192 Down GEMV: tile_m=2, m_input=1, herd_m=8 + +Note on Down GEMV "renaming": + The PRODUCTION MERGED ELF renames Down GEMV's @matvec to + @dg_matvec_vectorized_bf16_bf16 + link_with="mv_k8192.o" because two GEMVs + with different signatures can't coexist in one ELF with the same C symbol. + STANDALONE down_gemv has no such conflict — it's its own ELF — so it uses + the standard @matvec_vectorized_bf16_bf16 + mv.o (compiled with default + tile_m). The MLIR loop structure uses tile_m=2, m_input=1 from build_gemv. +""" + +from ml_dtypes import bfloat16 + +from specs.kernel_group import SubLaunchSpec, BatonLink, KernelGroupSpec + +# --------------------------------------------------------------------------- +# Sub-launch standalone builders +# --------------------------------------------------------------------------- + + +def _build_o_gemv_standalone(): + """O GEMV: (wo[2048,2048], attn_out[2048], proj[2048]).""" + from matvec import build_module as build_gemv + + return build_gemv(2048, 2048, 8, 4, 8, bfloat16, bfloat16) + + +def _build_add1_standalone(): + """Residual add #1 (post-attn): (proj[2048], x_residual[2048], res1[2048]). + + eltwise_add.build_module(M, N, ...) accepts 2D shape (M, N). Production + calls it with M=emb_dim, N=emb_dim//8, herd=[8,1] — so the 1D activation + is reshaped/tiled across M=emb_dim rows of N=emb_dim//8 cols. + + Wraps via _wrap_ir_in_launch (eltwise_add emits a bare herd). + """ + from eltwise_add.eltwise_add import build_module as build_add + from kernel_builder.stitching import _wrap_ir_in_launch + from air.ir import Module + + bare = str(build_add(2048, 2048 // 8, bfloat16, vector_size=16, herd_x=8, herd_y=1)) + return Module.parse(_wrap_ir_in_launch(bare)) + + +def _build_rmsnorm_standalone(): + """1D RMSNorm: (res1[2048], ffn_norm_w[2048], normed2[2048]). + + Imports _build_rms_1d_ir from o_gemv_ffn_multi (returns MLIR text) + and parses to a Module. This is the SAME 1D RMSNorm wrapper used by + the production merged ELF, so byte-equality is guaranteed. + """ + from multi_launch_builder.o_gemv_ffn_multi import _build_rms_1d_ir + from air.ir import Module + + return Module.parse(_build_rms_1d_ir(2048, vector_size=16)) + + +def _build_gate_or_up_gemv_standalone(): + """Gate or Up GEMV: (w[8192,2048], normed2[2048], out[8192]).""" + from matvec import build_module as build_gemv + + return build_gemv(8192, 2048, 8, 4, 8, bfloat16, bfloat16) + + +def _build_swiglu_standalone(): + """SwiGLU: (gate[8192], up[8192], swiglu[8192]). + + Uses kernel_builder.ffn_swiglu.silu_and_mul.build_module (1D variant). + Wraps via _wrap_ir_in_launch (silu emits a bare herd). + """ + from kernel_builder.ffn_swiglu.silu_and_mul import build_module as build_silu + from kernel_builder.stitching import _wrap_ir_in_launch + from air.ir import Module + + bare = str(build_silu(8192, 8192 // 8, bfloat16, herd_x=8, herd_y=1)) + return Module.parse(_wrap_ir_in_launch(bare)) + + +def _build_down_gemv_standalone(): + """Down GEMV: (wdown[2048,8192], swiglu[8192], down[2048]). + + Smaller tiles: tile_m=2, m_input=1 (production uses these for K=8192). + As a STANDALONE, uses the default mv.o — no rename needed (only the + merged ELF needs the rename to avoid C-symbol collision with K=2048 + GEMVs). + """ + from matvec import build_module as build_gemv + + return build_gemv(2048, 8192, 2, 1, 8, bfloat16, bfloat16) + + +def _build_add2_standalone(): + """Residual add #2 (post-FFN): (down[2048], res1[2048], output[2048]). + + Same builder as _build_add1_standalone — production uses the SAME + config (M=emb_dim, N=emb_dim//8, herd=[8,1]) for both residual adds. + """ + return _build_add1_standalone() + + +# --------------------------------------------------------------------------- +# KernelGroupSpec +# --------------------------------------------------------------------------- + +SPEC = KernelGroupSpec( + name="o_gemv_ffn", + sub_launches=( + # idx=0: O GEMV — slot 0=W (wo), slot 1=x (attn_out), slot 2=y (proj) + SubLaunchSpec("o_gemv", _build_o_gemv_standalone, {}, 0, 2), + # idx=1: Add (post-attn residual) — no weight, slot 0=A, 1=B, 2=res1 + SubLaunchSpec("add_attn_residual", _build_add1_standalone, {}, None, 2), + # idx=2: FFN RMSNorm — slot 0=x (res1), 1=norm_w, 2=normed2 + SubLaunchSpec("ffn_rmsnorm", _build_rmsnorm_standalone, {}, 1, 2), + # idx=3: Gate GEMV — slot 0=W (wgate), 1=x (normed2), 2=y (gate) + SubLaunchSpec("gate_gemv", _build_gate_or_up_gemv_standalone, {}, 0, 2), + # idx=4: Up GEMV — slot 0=W (wup), 1=x (normed2), 2=y (up) + SubLaunchSpec("up_gemv", _build_gate_or_up_gemv_standalone, {}, 0, 2), + # idx=5: SwiGLU — no weight, slot 0=gate, 1=up, 2=swiglu + SubLaunchSpec("swiglu", _build_swiglu_standalone, {}, None, 2), + # idx=6: Down GEMV — slot 0=W (wdown), 1=x (swiglu), 2=y (down) + SubLaunchSpec("down_gemv_k8192", _build_down_gemv_standalone, {}, 0, 2), + # idx=7: Add (FFN residual) — no weight, slot 0=A (down), 1=B (res1), 2=output + SubLaunchSpec("add_ffn_residual", _build_add2_standalone, {}, None, 2), + ), + merged_arg_signature=( + "wo", # 0 weight (static) + "attn_out", # 1 activation input + "proj", # 2 intermediate + "x_residual", # 3 activation input + "res1", # 4 intermediate (shared: add1 out + add2 B) + "ffn_norm_w", # 5 weight (static) + "normed2", # 6 intermediate + "wgate", # 7 weight (static) + "gate", # 8 intermediate + "wup", # 9 weight (static) + "up", # 10 intermediate + "swiglu", # 11 intermediate + "wdown", # 12 weight (static) + "down", # 13 intermediate + "output", # 14 intermediate (final output) + ), + weight_slots=frozenset({0, 5, 7, 9, 12}), + intermediate_slots=frozenset({2, 4, 6, 8, 10, 11, 13, 14}), + output_slots_for_validation=(14,), + baton_links=( + # Stitch arg_map verified against o_gemv_ffn_multi.py lines 394-403: + # L1 {0:0,1:1,2:2} L2 {0:2,1:3,2:4} L3 {0:4,1:5,2:6} + # L4 {0:7,1:6,2:8} L5 {0:9,1:6,2:10} L6 {0:8,1:10,2:11} + # L7 {0:12,1:11,2:13} L8 {0:13,1:4,2:14} + BatonLink(0, 2, 1, 0), # o_gemv.proj -> add_attn.A + BatonLink(1, 2, 2, 0), # add_attn.res1 -> ffn_rmsnorm.x + BatonLink(2, 2, 3, 1), # ffn_rmsnorm.normed2 -> gate_gemv.x (slot 1!) + BatonLink(2, 2, 4, 1), # ffn_rmsnorm.normed2 -> up_gemv.x (slot 1!) + BatonLink(3, 2, 5, 0), # gate_gemv.gate -> swiglu.gate + BatonLink(4, 2, 5, 1), # up_gemv.up -> swiglu.up + BatonLink(5, 2, 6, 1), # swiglu -> down_gemv.x (slot 1!) + BatonLink(6, 2, 7, 0), # down_gemv.down -> add_ffn.A + BatonLink(1, 2, 7, 1), # add_attn.res1 -> add_ffn.B (residual-of-residual) + ), +) diff --git a/programming_examples/llama32_1b/ablation/decode/specs/rms_gemv_rope.py b/programming_examples/llama32_1b/ablation/decode/specs/rms_gemv_rope.py new file mode 100644 index 000000000..64fd203f2 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/decode/specs/rms_gemv_rope.py @@ -0,0 +1,86 @@ +"""Concrete KernelGroupSpec for the decode rms_gemv_rope kernel-group. + +Mirrors the production stitch-spec in +multi_launch_builder/rms_gemv_rope_multi.py (the 6-launch decode pipeline: +RMSNorm + Q/K/V GEMV + RoPE Q + RoPE K). + +Slot conventions for standalones: + - rmsnorm: (x_in[emb], norm_w[emb], out[emb]) weight=1, out=2 + - gemv: (W[out, in], x[in], y[out]) weight=0, out=2 + (matvec convention — W is at slot 0, NOT slot 1 like prefill GEMM.) + - rope: (in_flat[N], lut[head_dim], out_flat[N]) weight=1 (LUT), out=2 + +Production decode shapes (single token): + emb_dim=2048, kv_dim=512, n_heads=32, n_kv_heads=8, head_dim=64. + q_total = n_heads * head_dim = 2048 (= emb_dim by construction) + k_total = n_kv_heads * head_dim = 512 (= kv_dim by construction) +""" + +from standalone_builders.rms_gemv_rope import STANDALONES as _PLAN0_STANDALONES +from specs.kernel_group import SubLaunchSpec, BatonLink, KernelGroupSpec + +# Plan 0's STANDALONES is a list of (name, build_fn, build_kwargs) tuples. +# Convert to a name→(build_fn, build_kwargs) lookup for SubLaunchSpec construction. +_BUILDERS = {name: (build_fn, kwargs) for name, build_fn, kwargs in _PLAN0_STANDALONES} + + +def _b(name): + """Helper: extract (build_fn, build_kwargs) for a sub-launch by name.""" + return _BUILDERS[name] + + +SPEC = KernelGroupSpec( + name="rms_gemv_rope", + sub_launches=( + # idx=0: RMSNorm — slot 0=x_in, slot 1=norm_w (weight), slot 2=normed (out) + SubLaunchSpec("rmsnorm", _b("rmsnorm")[0], _b("rmsnorm")[1], 1, 2), + # idx=1: Q GEMV — slot 0=W (wq), slot 1=x (normed), slot 2=y (q) + SubLaunchSpec("q_gemv", _b("q_gemv")[0], _b("q_gemv")[1], 0, 2), + # idx=2: K GEMV — slot 0=W (wk), slot 1=x, slot 2=y (k) + SubLaunchSpec("k_gemv", _b("k_gemv")[0], _b("k_gemv")[1], 0, 2), + # idx=3: V GEMV — slot 0=W (wv), slot 1=x, slot 2=y (v) + SubLaunchSpec("v_gemv", _b("v_gemv")[0], _b("v_gemv")[1], 0, 2), + # idx=4: RoPE Q — slot 0=in (q), slot 1=lut_q (weight), slot 2=out (q_roped) + SubLaunchSpec("rope_q", _b("rope_q")[0], _b("rope_q")[1], 1, 2), + # idx=5: RoPE K — slot 0=in (k), slot 1=lut_k, slot 2=out (k_roped) + SubLaunchSpec("rope_k", _b("rope_k")[0], _b("rope_k")[1], 1, 2), + ), + merged_arg_signature=( + "x_in", # 0 activation input + "norm_w", # 1 weight (static) + "normed", # 2 intermediate + "wq", # 3 weight (static) + "q", # 4 intermediate + "wk", # 5 weight (static) + "k", # 6 intermediate + "wv", # 7 weight (static) + "v", # 8 intermediate + "lut_q", # 9 weight (static) + "lut_k", # 10 weight (static) + "q_roped", # 11 intermediate (also output for validation) + "k_roped", # 12 intermediate (also output for validation) + ), + weight_slots=frozenset({1, 3, 5, 7, 9, 10}), + intermediate_slots=frozenset({2, 4, 6, 8, 11, 12}), + output_slots_for_validation=(2, 4, 6, 8, 11, 12), + baton_links=( + # rmsnorm.normed (slot 2) -> q/k/v_gemv.x (slot 1 — matvec convention!) + BatonLink( + producer_idx=0, producer_out_slot=2, consumer_idx=1, consumer_in_slot=1 + ), + BatonLink( + producer_idx=0, producer_out_slot=2, consumer_idx=2, consumer_in_slot=1 + ), + BatonLink( + producer_idx=0, producer_out_slot=2, consumer_idx=3, consumer_in_slot=1 + ), + # q_gemv.q (slot 2) -> rope_q.in (slot 0) + BatonLink( + producer_idx=1, producer_out_slot=2, consumer_idx=4, consumer_in_slot=0 + ), + # k_gemv.k (slot 2) -> rope_k.in (slot 0) + BatonLink( + producer_idx=2, producer_out_slot=2, consumer_idx=5, consumer_in_slot=0 + ), + ), +) diff --git a/programming_examples/llama32_1b/ablation/decode/standalone_builders/__init__.py b/programming_examples/llama32_1b/ablation/decode/standalone_builders/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/programming_examples/llama32_1b/ablation/decode/standalone_builders/o_gemv_ffn.py b/programming_examples/llama32_1b/ablation/decode/standalone_builders/o_gemv_ffn.py new file mode 100644 index 000000000..80b58448e --- /dev/null +++ b/programming_examples/llama32_1b/ablation/decode/standalone_builders/o_gemv_ffn.py @@ -0,0 +1,12 @@ +"""Single-launch standalone modules for the decode o_gemv_ffn kernel-group. + +Exports a STANDALONES registry compatible with cells/common.py:compile_standalone_kernels. +The actual builder functions live in specs/o_gemv_ffn.py (alongside the SPEC); this +module is a thin derived registry that converts SPEC.sub_launches → list of tuples. +""" + +from specs.o_gemv_ffn import SPEC + +STANDALONES = [ + (sub.name, sub.builder_ref, sub.build_kwargs) for sub in SPEC.sub_launches +] diff --git a/programming_examples/llama32_1b/ablation/decode/standalone_builders/rms_gemv_rope.py b/programming_examples/llama32_1b/ablation/decode/standalone_builders/rms_gemv_rope.py new file mode 100644 index 000000000..479403abd --- /dev/null +++ b/programming_examples/llama32_1b/ablation/decode/standalone_builders/rms_gemv_rope.py @@ -0,0 +1,55 @@ +"""Single-launch standalone MLIR modules for the decode rms_gemv_rope kernel-group. + +Each function returns a ready-to-compile mlir.Module containing exactly one +air.launch (or launch+segment for sub-builders that emit bare herds) at +production decode shape (single-token, emb_dim=2048, kv_dim=512, +n_heads=32, n_kv_heads=8, head_dim=64). + +These are the Cell-A/B/C inputs. Cell D reuses the production merged +build_rms_gemv_rope_module from multi_launch_builder/rms_gemv_rope_multi.py. + +The 6 sub-launches mirror the production stitch-spec in +multi_launch_builder/rms_gemv_rope_multi.py. +""" + +from ml_dtypes import bfloat16 + +from multi_launch_builder.rms_gemv_rope_multi import ( + _build_rms_1d, + _build_rope_1d, +) + + +def build_rmsnorm(emb_dim=2048): + """RMSNorm 1D: (x_in[emb_dim], norm_w[emb_dim]) -> normed[emb_dim].""" + return _build_rms_1d(emb_dim, bfloat16, 16) + + +def build_gemv(out_dim, in_dim, tile_m=8, m_input=4, herd_m=8): + """Generic decode GEMV: (W[out_dim, in_dim], x[in_dim]) -> y[out_dim]. + + Covers Q (out=emb=2048), K/V (out=kv=512). + """ + from matvec import build_module as _build_gemv + + return _build_gemv(out_dim, in_dim, tile_m, m_input, herd_m, bfloat16, bfloat16) + + +def build_rope(n_rows, head_dim=64, herd_x=1): + """RoPE 1D: (x_flat[n_rows*head_dim], lut[head_dim]) -> y_flat[n_rows*head_dim]. + + Covers RoPE Q (n_rows=n_heads=32) and RoPE K (n_rows=n_kv_heads=8). + """ + return _build_rope_1d(n_rows, head_dim, bfloat16, herd_x) + + +# Full registry of standalones for this kernel-group. +# Each entry: (name, build_fn, build_kwargs) +STANDALONES = [ + ("rmsnorm", build_rmsnorm, {"emb_dim": 2048}), + ("q_gemv", build_gemv, {"out_dim": 2048, "in_dim": 2048}), + ("k_gemv", build_gemv, {"out_dim": 512, "in_dim": 2048}), + ("v_gemv", build_gemv, {"out_dim": 512, "in_dim": 2048}), + ("rope_q", build_rope, {"n_rows": 32, "head_dim": 64}), + ("rope_k", build_rope, {"n_rows": 8, "head_dim": 64}), +] diff --git a/programming_examples/llama32_1b/ablation/decode/tests/__init__.py b/programming_examples/llama32_1b/ablation/decode/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/programming_examples/llama32_1b/ablation/decode/tests/conftest.py b/programming_examples/llama32_1b/ablation/decode/tests/conftest.py new file mode 100644 index 000000000..a671f3ed4 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/decode/tests/conftest.py @@ -0,0 +1,47 @@ +"""Pytest config for full-decode ablation tests. + +Inserts paths so tests can import: +- llama32_1b/ packages (kernel_builder, multi_launch_builder) +- llama32_1b/ablation/ (Plan 0's standalone_builders + validate.py) +- llama32_1b/ablation/prefill/ (Plan 1's cells, specs, common helpers) +- llama32_1b/ablation/decode/ (this package) +- programming_examples/ (matvec, weighted_rms_norm, ffn_swiglu) +""" + +import os +import sys + +_THIS = os.path.dirname(os.path.abspath(__file__)) +_DECODE = os.path.dirname(_THIS) +_ABLATION = os.path.dirname(_DECODE) +_LLAMA = os.path.dirname(_ABLATION) +_PROG_EXAMPLES = os.path.dirname(_LLAMA) + +for p in ( + _PROG_EXAMPLES, + _LLAMA, + _ABLATION, + os.path.join(_ABLATION, "prefill"), + _DECODE, +): + if p not in sys.path: + sys.path.insert(0, p) + +# Pytest may have already inserted other paths or pre-imported a `cells` package +# from prefill/. Force _DECODE to sys.path[0] AND drop any cached `cells*` modules +# so subsequent `from cells.X import Y` resolves to decode/cells/. +if sys.path[0] != _DECODE: + if _DECODE in sys.path: + sys.path.remove(_DECODE) + sys.path.insert(0, _DECODE) + +for _stale in [m for m in list(sys.modules) if m == "cells" or m.startswith("cells.")]: + del sys.modules[_stale] +for _stale in [m for m in list(sys.modules) if m == "specs" or m.startswith("specs.")]: + del sys.modules[_stale] +for _stale in [ + m + for m in list(sys.modules) + if m == "standalone_builders" or m.startswith("standalone_builders.") +]: + del sys.modules[_stale] diff --git a/programming_examples/llama32_1b/ablation/decode/tests/test_kv_cache_state.py b/programming_examples/llama32_1b/ablation/decode/tests/test_kv_cache_state.py new file mode 100644 index 000000000..d2036b86c --- /dev/null +++ b/programming_examples/llama32_1b/ablation/decode/tests/test_kv_cache_state.py @@ -0,0 +1,58 @@ +"""KV cache state must be deterministic and per-trial resettable.""" + +import numpy as np + +from cells.kv_cache import build_initial_kv_cache, reset_position + +CONFIG = { + "n_layers": 16, + "n_kv_heads": 8, + "head_dim": 64, + "max_seq": 2048, +} + + +def test_initial_cache_is_deterministic(): + c1 = build_initial_kv_cache(CONFIG, prompt_len=7, seed=42) + c2 = build_initial_kv_cache(CONFIG, prompt_len=7, seed=42) + assert c1["k_cache"].tobytes() == c2["k_cache"].tobytes() + assert c1["v_cache"].tobytes() == c2["v_cache"].tobytes() + assert c1["current_pos"] == 7 + assert c2["current_pos"] == 7 + + +def test_initial_cache_zeros_after_prompt_len(): + cache = build_initial_kv_cache(CONFIG, prompt_len=7, seed=42) + # Positions 7..max_seq-1 must be zeros + after = cache["k_cache"][:, :, 7:, :] + assert np.all(after.view(np.uint8) == 0) + after_v = cache["v_cache"][:, :, 7:, :] + assert np.all(after_v.view(np.uint8) == 0) + + +def test_initial_cache_nonzero_in_prompt_range(): + cache = build_initial_kv_cache(CONFIG, prompt_len=7, seed=42) + # At least some entries in [0:7] must be non-zero + pre = cache["k_cache"][:, :, :7, :] + assert not np.all(pre.view(np.uint8) == 0) + + +def test_reset_position_zeros_target_slot_only(): + cache = build_initial_kv_cache(CONFIG, prompt_len=7, seed=42) + # Simulate a kernel writing to position 7 in layer 0 + cache["k_cache"][0, :, 7, :] = 99.0 + cache["v_cache"][0, :, 7, :] = -42.0 + # Reset should zero position 7 across ALL layers + reset_position(cache, 7) + assert np.all(cache["k_cache"][:, :, 7, :].view(np.uint8) == 0) + assert np.all(cache["v_cache"][:, :, 7, :].view(np.uint8) == 0) + # Positions 0..6 must be untouched (still match a fresh init) + fresh = build_initial_kv_cache(CONFIG, prompt_len=7, seed=42) + assert ( + cache["k_cache"][:, :, :7, :].tobytes() + == fresh["k_cache"][:, :, :7, :].tobytes() + ) + assert ( + cache["v_cache"][:, :, :7, :].tobytes() + == fresh["v_cache"][:, :, :7, :].tobytes() + ) diff --git a/programming_examples/llama32_1b/ablation/decode/tests/test_validation_gate.py b/programming_examples/llama32_1b/ablation/decode/tests/test_validation_gate.py new file mode 100644 index 000000000..a561375cf --- /dev/null +++ b/programming_examples/llama32_1b/ablation/decode/tests/test_validation_gate.py @@ -0,0 +1,62 @@ +"""Verify Plan 1's validate.py works against the new decode goldens. + +Two goldens: golden_rms_gemv_rope_decode.npz and golden_o_gemv_ffn_decode.npz. +For each, two tests: + 1. Loading the golden and validating it against itself MUST pass. + 2. Mutating one byte and re-validating MUST raise GoldenMismatch. + +These tests do NOT touch the NPU. +""" + +import os + +import numpy as np +from ml_dtypes import bfloat16 + +from validate import GoldenMismatch, validate_against_golden + +GOLDEN_DIR = os.path.join(os.path.dirname(__file__), "..", "golden") + + +def _load(name): + return np.load(os.path.join(GOLDEN_DIR, name)) + + +def test_rms_gemv_rope_passes_on_exact_match(): + npz = _load("golden_rms_gemv_rope_decode.npz") + cell_outputs = {key: npz[key] for key in npz.files} + validate_against_golden(cell_outputs, GOLDEN_DIR, "golden_rms_gemv_rope_decode.npz") + + +def test_rms_gemv_rope_raises_on_byte_diff(): + npz = _load("golden_rms_gemv_rope_decode.npz") + perturbed = {k: npz[k].copy() for k in npz.files} + arr = perturbed["normed"].view(np.uint8).copy() + arr[0] ^= 0x01 # flip one bit + perturbed["normed"] = arr.view(bfloat16).reshape(npz["normed"].shape) + try: + validate_against_golden( + perturbed, GOLDEN_DIR, "golden_rms_gemv_rope_decode.npz" + ) + raise AssertionError("expected GoldenMismatch") + except GoldenMismatch: + pass + + +def test_o_gemv_ffn_passes_on_exact_match(): + npz = _load("golden_o_gemv_ffn_decode.npz") + cell_outputs = {key: npz[key] for key in npz.files} + validate_against_golden(cell_outputs, GOLDEN_DIR, "golden_o_gemv_ffn_decode.npz") + + +def test_o_gemv_ffn_raises_on_byte_diff(): + npz = _load("golden_o_gemv_ffn_decode.npz") + perturbed = {k: npz[k].copy() for k in npz.files} + arr = perturbed["output"].view(np.uint8).copy() + arr[0] ^= 0x01 + perturbed["output"] = arr.view(bfloat16).reshape(npz["output"].shape) + try: + validate_against_golden(perturbed, GOLDEN_DIR, "golden_o_gemv_ffn_decode.npz") + raise AssertionError("expected GoldenMismatch") + except GoldenMismatch: + pass diff --git a/programming_examples/llama32_1b/ablation/decode/validate.py b/programming_examples/llama32_1b/ablation/decode/validate.py new file mode 100644 index 000000000..46fd1f365 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/decode/validate.py @@ -0,0 +1,12 @@ +"""Re-export Plan 1's parameterized bit-exact validation gate. + +Plan 1's validate.py accepts a `golden_filename` parameter, so the same +function works for decode goldens too — just pass a different filename. +""" + +from prefill.validate import ( + validate_against_golden, + GoldenMismatch, +) + +__all__ = ["validate_against_golden", "GoldenMismatch"] diff --git a/programming_examples/llama32_1b/ablation/docs/plans/2026-05-07-llama32-1b-ablation-plan2-prefill.md b/programming_examples/llama32_1b/ablation/docs/plans/2026-05-07-llama32-1b-ablation-plan2-prefill.md new file mode 100644 index 000000000..4fe337914 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/docs/plans/2026-05-07-llama32-1b-ablation-plan2-prefill.md @@ -0,0 +1,2611 @@ +# Llama-3.2-1B Plan 2 (Prefill) Ablation Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Build the 4-cell ablation ladder for the **prefill** kernel-groups (`rms_gemms_rope` 6 launches + `o_ffn` 8 launches at seq=2048 GEMM shapes) using parameterized cells driven by `KernelGroupSpec` dataclasses. FA held constant per master spec. Single-layer + 16-layer scopes. Bit-exact validation against committed goldens. Headline number directly comparable to `profile.md`'s 1.27 s prefill. + +**Architecture:** Self-contained subdir `programming_examples/llama32_1b/ablation/prefill/` (Plan 1 files at top-level remain byte-immutable). 4 parameterized cell modules walk a `KernelGroupSpec` (one spec per kernel-group) describing sub-launches, slot semantics, and baton-pass topology. A 16-layer wrapper threads `o_ffn.output[L] → rms_gemms_rope.x_in[L+1]` with FA invariant between the two intra-layer kernel-groups. Reuses Plan 1's `KernelCache.naive=True`, `cells/common.py:compile_standalone_kernels` (helper extracted to `prefill/cells/common.py` and parameterized), and `validate.py` (verbatim, kernel-group-agnostic). + +**Tech Stack:** Python 3, numpy, ml_dtypes (bfloat16), pytest, mlir-air's `XRTBackend` + `KernelCache` + existing sub-builders (`build_rms_gemms_rope_module`, `build_o_ffn_module` from `multi_launch_builder/`; `_build_gemm_module` from `kernel_builder/gemm_builder.py`; `_build_rope_2d` from `multi_launch_builder/rms_gemms_rope_multi.py:63`; `_build_add_2d_to_2d` from `multi_launch_builder/o_ffn_multi.py`; `weighted_rms_norm.weighted_rms_norm.build_module`; `ffn_swiglu.silu_and_mul`). + +**Companion docs:** +- Plan 2 spec: `programming_examples/llama32_1b/ablation/docs/specs/2026-05-07-llama32-1b-ablation-plan2-prefill-design.md` +- Master ablation spec: removed from repo (decode pilot deleted; superseded by full-decode study at `ablation/docs/specs/2026-05-12-llama32-1b-ablation-plan2-fulldecode-design.md`) +- Plan 1 (decode pilot) plan: removed from repo (subsumed by full-decode study at `ablation/docs/plans/2026-05-12-llama32-1b-ablation-plan2-fulldecode-plan.md`) +- Plan 1's working code at `programming_examples/llama32_1b/ablation/` — removed; see `ablation/decode/` for the superseding study. + +--- + +## File Structure + +All paths under `programming_examples/llama32_1b/ablation/prefill/` unless noted. + +| File | Responsibility | +|---|---| +| `__init__.py` | Package marker | +| `README.md` | Methodology, run instructions, results, reproducibility | +| `Makefile` | `make compile / regen-golden / run / report / all / clean` | +| `specs/__init__.py` | Package marker | +| `specs/kernel_group.py` | Frozen dataclasses: `SubLaunchSpec`, `BatonLink`, `KernelGroupSpec` | +| `specs/rms_gemms_rope.py` | Concrete `KernelGroupSpec` instance for the 6-launch prefill attention pre-block | +| `specs/o_ffn.py` | Concrete `KernelGroupSpec` instance for the 8-launch prefill FFN block | +| `standalone_builders/__init__.py` | Package marker | +| `standalone_builders/rms_gemms_rope.py` | 6 single-launch builder wrappers + `STANDALONES` registry | +| `standalone_builders/o_ffn.py` | 8 single-launch builder wrappers + `STANDALONES` registry | +| `cells/__init__.py` | Package marker | +| `cells/common.py` | `compile_standalone_kernels` (parameterized), `_extract_public_func_name`, `_share_bo`, `standalone_backend_kwargs` helpers | +| `cells/cell_a_naive.py` | Parameterized Cell A — walks a `KernelGroupSpec` with `naive=True` | +| `cells/cell_b_static.py` | Parameterized Cell B — preload weights, then `static_input_indices` | +| `cells/cell_c_charitable.py` | Parameterized Cell C — preload + alias intermediate BOs per `spec.baton_links` | +| `cells/cell_d_merged.py` | Wrapper around production `build_rms_gemms_rope_module` and `build_o_ffn_module` | +| `cells/flash_attn_const.py` | FA invariant: compile + invoke production FA ELF identically across all cells | +| `cells/multi_layer.py` | Wraps a per-layer triple (rms_gemms_rope → FA → o_ffn) in a 16-layer loop | +| `golden/__init__.py` | Package marker | +| `golden/regen_golden.py` | One-shot Cell-D run for layer 0; dumps two npz fixtures + meta json | +| `golden/golden_rms_gemms_rope_prefill.npz` | Committed bit-exact reference (Cell D's 6 outputs, layer 0, seed=42) | +| `golden/golden_o_ffn_prefill.npz` | Committed bit-exact reference (Cell D's relevant outputs for o_ffn, layer 0, seed=42) | +| `golden/golden_meta.json` | Hashes, shapes, config | +| `run_ablation.py` | Orchestrator: validate → time × {single-layer, 16-layer} × 4 cells, emit JSON | +| `analyze.py` | JSON → markdown report | +| `tests/__init__.py` | Package marker | +| `tests/conftest.py` | Pytest sys.path setup | +| `tests/test_kernel_group_spec.py` | Dataclass invariants (NPU-free) | +| `tests/test_parameterized_cells.py` | Mock-cache tests verifying each cell walks its spec correctly (NPU-free) | +| `tests/test_validation_gate.py` | Imports Plan 1's `validate.py` and tests it against new prefill goldens | + +**Files NOT touched (Plan 1 isolation guarantee):** every file under `programming_examples/llama32_1b/ablation/` outside `prefill/`. Production code (`programming_examples/llama32_1b/kernel_builder/`, `multi_launch_builder/`) read-only — only imported. + +--- + +## Phase 1 — Skeleton + Specs (Tasks 1–4) + +## Task 1: Subdir skeleton + pytest conftest + +**Files:** +- Create: 9 `__init__.py` files (one per package directory) +- Create: `programming_examples/llama32_1b/ablation/prefill/tests/conftest.py` + +- [ ] **Step 1: Create empty package markers** + +```bash +mkdir -p programming_examples/llama32_1b/ablation/prefill/{specs,standalone_builders,cells,golden,tests} +for d in prefill prefill/specs prefill/standalone_builders prefill/cells prefill/golden prefill/tests; do + touch programming_examples/llama32_1b/ablation/$d/__init__.py +done +``` + +- [ ] **Step 2: Write conftest.py** + +`programming_examples/llama32_1b/ablation/prefill/tests/conftest.py`: + +```python +"""Pytest config for prefill ablation tests. + +Inserts paths so tests can import: +- llama32_1b/ packages (kernel_builder, multi_launch_builder) +- llama32_1b/ablation/ (Plan 1's validate.py and shared helpers) +- llama32_1b/ablation/prefill/ (this package) +- programming_examples/ (matvec, weighted_rms_norm, ffn_swiglu) +""" + +import os +import sys + +_THIS = os.path.dirname(os.path.abspath(__file__)) +_PREFILL = os.path.dirname(_THIS) +_ABLATION = os.path.dirname(_PREFILL) +_LLAMA = os.path.dirname(_ABLATION) +_PROG_EXAMPLES = os.path.dirname(_LLAMA) + +for p in (_PROG_EXAMPLES, _LLAMA, _ABLATION, _PREFILL): + if p not in sys.path: + sys.path.insert(0, p) + +# Pytest's package-import mode inserts the package parent (ablation/) into sys.path[0] +# before this conftest runs, which can shadow prefill/validate.py with ablation/validate.py. +# Guarantee that prefill/ is at index 0 so prefill-local modules take priority. +if sys.path[0] != _PREFILL: + sys.path.remove(_PREFILL) if _PREFILL in sys.path else None + sys.path.insert(0, _PREFILL) +``` + +> **Implementation note (T10 wash-up):** The final three lines above were added in T10 +> to fix pytest's package-import mode inserting `ablation/` at `sys.path[0]` before the +> conftest ran, shadowing `prefill/validate.py` with `ablation/validate.py`. The fix +> always-removes-then-reinserts `_PREFILL` at index 0 after the initial insertion loop. + +- [ ] **Step 3: Verify pytest discovers the empty test dir** + +Run: `cd programming_examples/llama32_1b/ablation/prefill && python3 -m pytest tests/ -v` +Expected: `no tests ran in 0.0Xs` (zero tests, zero errors). + +- [ ] **Step 4: Commit** + +```bash +git add programming_examples/llama32_1b/ablation/prefill/ +git commit -m "ablation/prefill: scaffold subdir skeleton with pytest conftest" +``` + +--- + +## Task 2: Spec dataclasses (`SubLaunchSpec`, `BatonLink`, `KernelGroupSpec`) + +**Files:** +- Create: `programming_examples/llama32_1b/ablation/prefill/specs/kernel_group.py` +- Test: `programming_examples/llama32_1b/ablation/prefill/tests/test_kernel_group_spec.py` + +- [ ] **Step 1: Write the failing test** + +`prefill/tests/test_kernel_group_spec.py`: + +```python +"""Unit tests for the KernelGroupSpec dataclasses.""" + +import pytest +from specs.kernel_group import SubLaunchSpec, BatonLink, KernelGroupSpec + + +def _dummy_builder(): + return None # Spec test doesn't need a real builder + + +def test_sublaunch_spec_is_frozen(): + s = SubLaunchSpec( + name="rms", + builder_ref=_dummy_builder, + build_kwargs={"emb_dim": 2048}, + weight_slot_in_standalone=1, + output_slot_in_standalone=2, + ) + with pytest.raises((AttributeError, TypeError)): # frozen + s.name = "other" + + +def test_baton_link_orders_by_indices(): + link = BatonLink(producer_idx=0, producer_out_slot=2, + consumer_idx=1, consumer_in_slot=1) + assert link.consumer_idx > link.producer_idx + + +def test_kernel_group_spec_holds_sublaunches(): + sub = SubLaunchSpec("rms", _dummy_builder, {}, 1, 2) + spec = KernelGroupSpec( + name="rms_gemms_rope", + sub_launches=(sub,), # tuple — frozen dataclass + merged_arg_signature=("x_in", "norm_w", "normed"), + weight_slots=frozenset({1}), + intermediate_slots=frozenset({2}), + output_slots_for_validation=(2,), + baton_links=(), + ) + assert spec.name == "rms_gemms_rope" + assert len(spec.sub_launches) == 1 + + +def test_baton_link_consumer_must_follow_producer(): + """A baton link with consumer_idx <= producer_idx is meaningless; + spec dataclass tolerates it but a validator rejects.""" + from specs.kernel_group import validate_baton_links + sub_a = SubLaunchSpec("a", _dummy_builder, {}, 1, 2) + sub_b = SubLaunchSpec("b", _dummy_builder, {}, 1, 2) + bad = BatonLink(producer_idx=1, producer_out_slot=2, consumer_idx=0, consumer_in_slot=1) + with pytest.raises(ValueError, match="consumer_idx"): + validate_baton_links([sub_a, sub_b], [bad]) +``` + +- [ ] **Step 2: Run test, expect FAIL** + +Run: `cd programming_examples/llama32_1b/ablation/prefill && python3 -m pytest tests/test_kernel_group_spec.py -v` +Expected: `ModuleNotFoundError: No module named 'specs.kernel_group'`. + +- [ ] **Step 3: Implement `specs/kernel_group.py`** + +```python +"""Frozen dataclasses describing a multi-launch kernel-group's structure. + +A KernelGroupSpec is consumed by parameterized cells (cell_a/b/c/d) so that +the same cell logic works for any kernel-group whose spec is provided. +""" + +from dataclasses import dataclass +from typing import Callable + + +@dataclass(frozen=True) +class SubLaunchSpec: + """One sub-launch's standalone definition. + + Used by Cell A/B/C to invoke the sub-launch as its own xrt.run() call. + Cell D ignores SubLaunchSpec entirely (it uses the merged ELF). + """ + name: str # "rmsnorm" | "q_gemm" | "rope_q" | ... + builder_ref: Callable # returns a 1-launch mlir.Module at production shape + build_kwargs: dict # passed verbatim to builder_ref + weight_slot_in_standalone: int | None # arg slot of the standalone call holding the weight (or None) + output_slot_in_standalone: int # arg slot of the standalone call holding the output + + +@dataclass(frozen=True) +class BatonLink: + """An intermediate-BO alias to apply in Cell C. + + The producer's output BO becomes the consumer's input BO; the host + skips writing the consumer's input slot via intermediate_indices. + """ + producer_idx: int # index into KernelGroupSpec.sub_launches + producer_out_slot: int # output slot of producer's standalone signature + consumer_idx: int # index into KernelGroupSpec.sub_launches (must be > producer_idx) + consumer_in_slot: int # input slot of consumer's standalone signature + + +@dataclass(frozen=True) +class KernelGroupSpec: + """Full description of a multi-launch kernel-group for ablation.""" + name: str # "rms_gemms_rope" | "o_ffn" + sub_launches: tuple # tuple of SubLaunchSpec (frozen) + merged_arg_signature: tuple # tuple of arg-name strings matching production merged ELF args + weight_slots: frozenset # slots in merged signature that are weights/LUTs (Cell D static_input_indices) + intermediate_slots: frozenset # slots in merged signature that are kernel-overwritten intermediates + output_slots_for_validation: tuple # slots whose bytes go in the golden npz + baton_links: tuple # tuple of BatonLink (Cell C aliases these intermediate BOs) + + +def validate_baton_links(sub_launches, baton_links): + """Sanity check: each link's consumer must come after its producer in the sequence.""" + for link in baton_links: + if link.consumer_idx <= link.producer_idx: + raise ValueError( + f"baton link consumer_idx={link.consumer_idx} must be greater than " + f"producer_idx={link.producer_idx}" + ) + if link.producer_idx >= len(sub_launches): + raise ValueError(f"producer_idx {link.producer_idx} out of range") + if link.consumer_idx >= len(sub_launches): + raise ValueError(f"consumer_idx {link.consumer_idx} out of range") +``` + +- [ ] **Step 4: Re-run the test** + +Run: `cd programming_examples/llama32_1b/ablation/prefill && python3 -m pytest tests/test_kernel_group_spec.py -v` +Expected: 4 passed. + +- [ ] **Step 5: Commit** + +```bash +git add programming_examples/llama32_1b/ablation/prefill/specs/ \ + programming_examples/llama32_1b/ablation/prefill/tests/test_kernel_group_spec.py +git commit -m "ablation/prefill: KernelGroupSpec/SubLaunchSpec/BatonLink dataclasses" +``` + +--- + +## Task 3: Concrete `KernelGroupSpec` for `rms_gemms_rope` + +**Files:** +- Create: `programming_examples/llama32_1b/ablation/prefill/specs/rms_gemms_rope.py` + +**Reference:** Production builder at `programming_examples/llama32_1b/multi_launch_builder/rms_gemms_rope_multi.py:193`. Merged signature has 13 args (slots 0-12); see docstring at lines 211-228 of that file. Static slots: {1, 3, 5, 7, 9, 10}. Intermediate slots: {2, 4, 6, 8, 11, 12}. + +The 6 sub-launches: +| Idx | Name | Builder | Production-shape kwargs | weight_slot | output_slot | +|---|---|---|---|---|---| +| 0 | rmsnorm | `weighted_rms_norm.weighted_rms_norm.build_module` (wrapped via `_wrap_ir_in_launch`) | `seq_len=2048, emb_dim=2048, np_dtype=bfloat16, vector_size=16, herd_x=8` | 1 (norm_w) | 2 (normed) | +| 1 | q_gemm | `kernel_builder.gemm_builder._build_gemm_module` | `seq_len=2048, K=2048, N=2048, tile_m=64, tile_k_l2=64, tile_k_l1=32, tile_n=128, herd_m=8, herd_n=4` | 1 (W) | 2 (Y) | +| 2 | k_gemm | same | `seq_len=2048, K=2048, N=512, tile_m=64, tile_k_l2=64, tile_k_l1=32, tile_n=128, herd_m=8, herd_n=4` | 1 | 2 | +| 3 | v_gemm | same | `seq_len=2048, K=2048, N=512, tile_m=64, tile_k_l2=64, tile_k_l1=32, tile_n=128, herd_m=8, herd_n=4` | 1 | 2 | +| 4 | rope_q | `multi_launch_builder.rms_gemms_rope_multi._build_rope_2d` | `outer_rows=2048, outer_cols=2048, embed_dim=64, np_dtype=bfloat16, herd_x=8` | 1 (lut) | 2 (out) | +| 5 | rope_k | same | `outer_rows=2048, outer_cols=512, embed_dim=64, np_dtype=bfloat16, herd_x=8` | 1 | 2 | + +Baton links (within-group only; cross-group host hop is invariant per spec): +- (0, 2) → (1, 0) rmsnorm.normed → q_gemm.x (slot 0 of standalone gemm = the activation input) +- (0, 2) → (2, 0) rmsnorm.normed → k_gemm.x +- (0, 2) → (3, 0) rmsnorm.normed → v_gemm.x +- (1, 2) → (4, 0) q_gemm.q → rope_q.in +- (2, 2) → (5, 0) k_gemm.k → rope_k.in + +Note: the standalone GEMM signature (`_build_gemm_module`) per its docstring has args `(M, A, B, C)` — verify this in the actual file. If args are `(A, B, C)` then weight slot is 1 (B), activation slot is 0 (A), output slot is 2 (C). The implementer must inspect `kernel_builder/gemm_builder.py:107` to confirm slot positions before finalizing the spec. + +- [ ] **Step 1: Write the spec module** + +```python +"""Concrete KernelGroupSpec for the prefill rms_gemms_rope kernel-group. + +Mirrors the production stitch-spec in +multi_launch_builder/rms_gemms_rope_multi.py:467-474 (which lists the +arg mappings for the 6 sub-launches in the merged ELF). + +Slot conventions for standalones: + - rmsnorm: (x_in[seq, emb], norm_w[emb], out[seq, emb]) output at slot 2 + - gemm: (a[seq, K], b[K, N], c[seq, N]) output at slot 2 + (verify via kernel_builder/gemm_builder.py:107 — ordering may + be (M, A, B, C); if so, weight slot becomes 2 not 1.) + - rope_2d: (in_2d[rows, cols], lut_1d[N], out_2d[rows, cols]) output at slot 2 +""" + +from ml_dtypes import bfloat16 + +from specs.kernel_group import SubLaunchSpec, BatonLink, KernelGroupSpec + + +def _build_rmsnorm_standalone(): + """Wrap weighted_rms_norm in air.launch+segment for solo invocation.""" + from weighted_rms_norm.weighted_rms_norm import build_module as build_rms + from kernel_builder.stitching import _wrap_ir_in_launch + from air.ir import Module + bare = str(build_rms(2048, 2048, bfloat16, 16, herd_x=8)) + wrapped_text = _wrap_ir_in_launch(bare) + return Module.parse(wrapped_text) + + +def _build_gemm_standalone(k, n): + """Production prefill GEMM: (seq=2048, k, n) with the production tile config. + + _build_gemm_module signature: (m, k, n, tile_m, tile_k_l2, tile_k_l1, tile_n, + herd_m, herd_n). Slots in standalone: 0=A (activation), 1=B (weight), 2=C (output). + """ + from kernel_builder.gemm_builder import _build_gemm_module + + return _build_gemm_module( + 2048, + k, + n, + tile_m=64, + tile_k_l2=64, + tile_k_l1=32, + tile_n=128, + herd_m=8, + herd_n=4, + ) + + +def _build_rope_2d_standalone(outer_rows, outer_cols): + from multi_launch_builder.rms_gemms_rope_multi import _build_rope_2d + return _build_rope_2d(outer_rows, outer_cols, 64, bfloat16, herd_x=8) + + +SPEC = KernelGroupSpec( + name="rms_gemms_rope", + sub_launches=( + SubLaunchSpec("rmsnorm", _build_rmsnorm_standalone, {}, 1, 2), + SubLaunchSpec("q_gemm", _build_gemm_standalone, {"k": 2048, "n": 2048}, 1, 2), + SubLaunchSpec("k_gemm", _build_gemm_standalone, {"k": 2048, "n": 512}, 1, 2), + SubLaunchSpec("v_gemm", _build_gemm_standalone, {"k": 2048, "n": 512}, 1, 2), + SubLaunchSpec("rope_q", _build_rope_2d_standalone, {"outer_rows": 2048, "outer_cols": 2048}, 1, 2), + SubLaunchSpec("rope_k", _build_rope_2d_standalone, {"outer_rows": 2048, "outer_cols": 512}, 1, 2), + ), + merged_arg_signature=( + "x_in", "norm_w", "normed", + "wq", "q", + "wk", "k", + "wv", "v", + "lut_q", "lut_k", + "q_roped", "k_roped", + ), + weight_slots=frozenset({1, 3, 5, 7, 9, 10}), + intermediate_slots=frozenset({2, 4, 6, 8, 11, 12}), + output_slots_for_validation=(2, 4, 6, 8, 11, 12), + baton_links=( + BatonLink(producer_idx=0, producer_out_slot=2, consumer_idx=1, consumer_in_slot=0), # rmsnorm.normed -> q_gemm.x + BatonLink(producer_idx=0, producer_out_slot=2, consumer_idx=2, consumer_in_slot=0), # rmsnorm.normed -> k_gemm.x + BatonLink(producer_idx=0, producer_out_slot=2, consumer_idx=3, consumer_in_slot=0), # rmsnorm.normed -> v_gemm.x + BatonLink(producer_idx=1, producer_out_slot=2, consumer_idx=4, consumer_in_slot=0), # q_gemm.q -> rope_q.in + BatonLink(producer_idx=2, producer_out_slot=2, consumer_idx=5, consumer_in_slot=0), # k_gemm.k -> rope_k.in + ), +) +``` + +- [ ] **Step 2: Verify the spec validates** + +Run: +```bash +cd programming_examples/llama32_1b/ablation/prefill +python3 -c " +from specs.rms_gemms_rope import SPEC +from specs.kernel_group import validate_baton_links +validate_baton_links(SPEC.sub_launches, SPEC.baton_links) +print(f'{SPEC.name}: {len(SPEC.sub_launches)} sub-launches, {len(SPEC.baton_links)} baton links') +" +``` +Expected: `rms_gemms_rope: 6 sub-launches, 5 baton links`. + +If it errors on `_build_gemm_module` signature mismatch (e.g., the function takes positional args in a different order), fix the keyword arg names to match `kernel_builder/gemm_builder.py:107`. The implementer should read that function's signature first; if it requires an `M` parameter or has different defaults, adjust `_build_gemm_standalone` accordingly. + +- [ ] **Step 3: Commit** + +```bash +git add programming_examples/llama32_1b/ablation/prefill/specs/rms_gemms_rope.py +git commit -m "ablation/prefill: concrete spec for rms_gemms_rope (6 sub-launches at seq=2048)" +``` + +--- + +## Task 4: Concrete `KernelGroupSpec` for `o_ffn` + +**Files:** +- Create: `programming_examples/llama32_1b/ablation/prefill/specs/o_ffn.py` + +**Reference:** Production builder at `multi_launch_builder/o_ffn_multi.py:178`. Merged signature has 15 args (slots 0-14); see docstring at lines 209-228. Static slots: {1, 5, 7, 9, 12}. Intermediate slots: {2, 4, 6, 8, 10, 11, 13, 14}. Slot 0 (`attn_out`) and slot 3 (`x_residual`) are activation inputs (written every call). + +The 8 sub-launches per `o_ffn_multi.py`: +| Idx | Name | Builder | Production-shape kwargs | +|---|---|---|---| +| 0 | o_gemm | `_build_gemm_module` | `seq_len=2048, K=2048, N=2048, tile_m=64, tile_k_l2=256, tile_k_l1=32, tile_n=64, herd_m=8, herd_n=4` | +| 1 | res_add | `_build_add_2d_to_2d` | `seq_len=2048, emb_dim=2048, np_dtype=bfloat16` | +| 2 | ffn_rmsnorm | wrapped `weighted_rms_norm.build_module` | `seq_len=2048, emb_dim=2048, np_dtype=bfloat16, vector_size=16, herd_x=8` | +| 3 | gate_gemm | `_build_gemm_module` | `seq_len=2048, K=2048, N=8192, tile_m=64, tile_k_l2=64, tile_k_l1=32, tile_n=128, herd_m=8, herd_n=4` | +| 4 | up_gemm | same | `seq_len=2048, K=2048, N=8192, tile_m=64, tile_k_l2=64, tile_k_l1=32, tile_n=128, herd_m=8, herd_n=4` | +| 5 | swiglu | `ffn_swiglu.silu_and_mul.build_module` (or wrapped per existing usage) | `seq_len=2048, hidden_dim=8192, tile_n=4096, herd_x=8, herd_y=1, np_dtype=bfloat16` | +| 6 | down_gemm | `_build_gemm_module` | `seq_len=2048, K=8192, N=2048, tile_m=64, tile_k_l2=256, tile_k_l1=32, tile_n=64, herd_m=8, herd_n=4` | +| 7 | ffn_add | `_build_add_2d_to_2d` (or its 1D variant — verify via o_ffn_multi.py) | `seq_len=2048, emb_dim=2048, np_dtype=bfloat16` | + +Baton links (within-group): +- (0, 2) → (1, 0) o_gemm.proj → res_add.A (a 2D add takes 2 activation inputs + 1 output) +- (1, 2) → (2, 0) res_add.res1 → ffn_rmsnorm.x (and also feeds ffn_add later as residual) +- (2, 2) → (3, 0) ffn_rmsnorm.normed2 → gate_gemm.x +- (2, 2) → (4, 0) ffn_rmsnorm.normed2 → up_gemm.x +- (3, 2) → (5, 0) gate_gemm.gate → swiglu.gate +- (4, 2) → (5, 1) up_gemm.up → swiglu.up +- (5, 2) → (6, 0) swiglu.swiglu → down_gemm.x +- (6, 2) → (7, 0) down_gemm.down → ffn_add.A +- (1, 2) → (7, 1) res_add.res1 → ffn_add.B (residual-of-residual; verify against o_ffn_multi.py — the ffn_add's second input is the post-attention residual, which equals res1) + +The implementer should inspect `o_ffn_multi.py` to confirm sub-launch order, exact arg slot conventions for the 2D add and SwiGLU, and the residual connectivity in step 7. If `_build_add_2d_to_2d` takes 3 args `(A, B, C)` then activation inputs are slots 0 and 1, output is slot 2. SwiGLU's `silu_and_mul` typically takes `(gate, up, out)` — slot 0 is gate, slot 1 is up, slot 2 is output. + +- [ ] **Step 1: Read `o_ffn_multi.py:178-450`** to confirm the exact sub-builder signatures and arg-mapping (see the stitch-spec around line 350-400 of that file). + +- [ ] **Step 2: Write the spec module** + +> **Implementation note (post-execution wash-up):** Three deviations from the original spec were necessary: +> 1. SwiGLU import is `kernel_builder.ffn_swiglu.silu_and_mul.build_module_2d` (the 2D memref +> variant, signature `(rows, cols, tile_n, np_dtype, herd_x, herd_y)`) — not `ffn_swiglu.silu_and_mul.build_module`. +> It already emits `air.launch`; no `_wrap_ir_in_launch` needed. +> 2. `ffn_add` uses `_build_ffn_add_standalone` (replicated from the nested `_build_add_2d_to_1d` +> inside `o_ffn_multi.py`, which cannot be imported directly) — not `_build_add_2d_to_2d`. +> Its output is 1D `[n_total]` (2D inputs, 1D output). +> 3. `air.ir` does not export `T`; use `IntegerType.get_signless(32)` instead. + +```python +"""Concrete KernelGroupSpec for the prefill o_ffn kernel-group. + +Mirrors the production stitch-spec in multi_launch_builder/o_ffn_multi.py. +8 sequential launches at seq=2048, emb_dim=2048, hidden_dim=8192: + + L1 o_gemm [8,4] attn_out x wo -> proj + L2 res_add [8,1] proj + x_residual -> res1 (2D out) + L3 ffn_rmsnorm [8,1] res1 x ffn_norm_w -> normed2 + L4 gate_gemm [8,4] normed2 x w_gate -> gate + L5 up_gemm [8,4] normed2 x w_up -> up + L6 swiglu [8,1] SiLU(gate) x up -> swiglu + L7 down_gemm [8,4] swiglu x w_down -> down + L8 ffn_add [8,1] down + res1 -> output (1D out) + +15 merged-func args (slots 0-14); static slots {1,5,7,9,12}; +intermediate slots {2,4,6,8,10,11,13,14}. + +Slot conventions per sub-launch standalone signatures: + - gemm: (A[seq,K], B[K,N], C[seq,N]) weight=1, out=2 + - add_2d_to_2d: (A[seq,d], B[seq,d], C[seq,d]) no weight, out=2 + - rmsnorm: (x[seq,d], w[d], out[seq,d]) weight=1, out=2 + - swiglu_2d: (gate[seq,h], up[seq,h], out[seq,h]) no weight, out=2 + - ffn_add: (A[seq,d], B[seq,d], out[n_total]) no weight, out=2 +""" + +from ml_dtypes import bfloat16 + +from specs.kernel_group import SubLaunchSpec, BatonLink, KernelGroupSpec + +# --------------------------------------------------------------------------- +# Sub-launch standalone builders +# --------------------------------------------------------------------------- + + +def _build_o_gemm_standalone(): + """O projection GEMM: attn_out(2048,2048) x wo(2048,2048) -> proj(2048,2048).""" + from kernel_builder.gemm_builder import _build_gemm_module + + return _build_gemm_module( + 2048, + 2048, + 2048, + tile_m=64, + tile_k_l2=256, + tile_k_l1=32, + tile_n=64, + herd_m=8, + herd_n=4, + ) + + +def _build_res_add_standalone(): + """Residual add (2D->2D): proj + x_residual -> res1.""" + from multi_launch_builder.o_ffn_multi import _build_add_2d_to_2d + + return _build_add_2d_to_2d(2048, 2048, bfloat16) + + +def _build_rmsnorm_standalone(): + """FFN RMSNorm (bare herd -> wrap in air.launch).""" + from weighted_rms_norm.weighted_rms_norm import build_module as build_rms + from kernel_builder.stitching import _wrap_ir_in_launch + from air.ir import Module + + bare = str(build_rms(2048, 2048, bfloat16, 16, herd_x=8)) + return Module.parse(_wrap_ir_in_launch(bare)) + + +def _build_gateup_gemm_standalone(n): + """Gate or Up GEMM: normed2(2048,2048) x w(2048,n) -> out(2048,n).""" + from kernel_builder.gemm_builder import _build_gemm_module + + return _build_gemm_module( + 2048, + 2048, + n, + tile_m=64, + tile_k_l2=64, + tile_k_l1=32, + tile_n=128, + herd_m=8, + herd_n=4, + ) + + +def _build_swiglu_standalone(): + """SwiGLU activation: SiLU(gate) * up -> swiglu (2D memref variant). + + Uses build_module_2d from kernel_builder/ffn_swiglu/silu_and_mul.py. + Signature: (rows, cols, tile_n, np_dtype_in, herd_x=8, herd_y=1). + Already wraps in air.launch -- no _wrap_ir_in_launch needed. + Arg slots in standalone: 0=gate, 1=up, 2=out. + """ + from kernel_builder.ffn_swiglu.silu_and_mul import build_module_2d as build_swiglu + + return build_swiglu(2048, 8192, 4096, bfloat16, herd_x=8, herd_y=1) + + +def _build_down_gemm_standalone(): + """Down GEMM: swiglu(2048,8192) x w_down(8192,2048) -> down(2048,2048).""" + from kernel_builder.gemm_builder import _build_gemm_module + + return _build_gemm_module( + 2048, + 8192, + 2048, + tile_m=64, + tile_k_l2=256, + tile_k_l1=32, + tile_n=64, + herd_m=8, + herd_n=4, + ) + + +def _build_ffn_add_standalone(): + """FFN Add (2D inputs -> 1D output): down + res1 -> output[n_total]. + + Replicated from the nested _build_add_2d_to_1d() in o_ffn_multi.py + (that function is defined inline inside build_o_ffn_module and cannot + be imported directly). + + Arg slots: 0=A (down, 2D), 1=B (res1, 2D), 2=out (1D). + """ + from air.ir import ( + AffineConstantExpr, + AffineExpr, + AffineMap, + AffineMapAttr, + AffineSymbolExpr, + IntegerAttr, + IntegerType, + MemRefType, + VectorType, + UnitAttr, + StringAttr, + ) + from air.dialects.affine import apply as affine_apply + from air.dialects.air import launch, segment, herd, module_builder + from air.dialects.memref import ( + collapse_shape as memref_collapse_shape, + AllocOp, + DeallocOp, + subview, + ) + from air.dialects.func import FuncOp + from air.dialects.scf import for_, yield_ + from air.dialects import arith + from air.dialects.vector import transfer_read, transfer_write + from air.backend.xrt_runner import type_mapper + from air.dialects.air import MemorySpace + + seq_len = 2048 + emb_dim = 2048 + n_total = seq_len * emb_dim + total_tiles = 8 + chunk_size = n_total // total_tiles + tile_n = emb_dim + + @module_builder + def _build(): + xrt_dtype = type_mapper(bfloat16) + l3_2d_ty = MemRefType.get([seq_len, emb_dim], xrt_dtype) + l3_1d_ty = MemRefType.get([n_total], xrt_dtype) + l1_space = IntegerAttr.get(IntegerType.get_signless(32), MemorySpace.L1) + l1_ty = MemRefType.get([tile_n], xrt_dtype, memory_space=l1_space) + vec_ty = VectorType.get([16], xrt_dtype) + identity_map = AffineMapAttr.get(AffineMap.get_identity(1)) + + @FuncOp.from_py_func(l3_2d_ty, l3_2d_ty, l3_1d_ty) + def eltwise_add(a_2d, b_2d, out_1d): + @launch(operands=[a_2d, b_2d, out_1d]) + def add_launch(l_a, l_b, l_out): + a_flat = memref_collapse_shape(l3_1d_ty, l_a, [[0, 1]]) + b_flat = memref_collapse_shape(l3_1d_ty, l_b, [[0, 1]]) + + @segment(name="add_seg", operands=[a_flat, b_flat, l_out]) + def add_seg(s_a, s_b, s_out): + offset_map = AffineMap.get( + 0, + 3, + [ + AffineExpr.get_add( + AffineSymbolExpr.get(0), + AffineExpr.get_mul( + AffineExpr.get_add( + AffineExpr.get_mul( + AffineSymbolExpr.get(1), + AffineConstantExpr.get(1), + ), + AffineSymbolExpr.get(2), + ), + AffineConstantExpr.get(chunk_size), + ), + ) + ], + ) + + @herd( + name="add_herd", + sizes=[8, 1], + operands=[s_a, s_b, s_out], + ) + def add_body(_tx, _ty, _sx, _sy, h_a, h_b, h_out): + l1_a = AllocOp(l1_ty, [], []) + l1_b = AllocOp(l1_ty, [], []) + l1_out = AllocOp(l1_ty, [], []) + c0 = arith.ConstantOp.create_index(0) + cst0 = arith.ConstantOp(xrt_dtype, 0.0) + for loop_iv in for_(0, chunk_size, tile_n): + offset = affine_apply(offset_map, [loop_iv, _tx, _ty]) + from air.dialects.air import dma_memcpy_nd + + dma_memcpy_nd( + l1_a, + h_a, + src_offsets=[offset], + src_sizes=[tile_n], + src_strides=[1], + ) + dma_memcpy_nd( + l1_b, + h_b, + src_offsets=[offset], + src_sizes=[tile_n], + src_strides=[1], + ) + for j in for_(0, tile_n, 16): + sub_a = subview(l1_a.result, [j], [16], [1]) + sub_b = subview(l1_b.result, [j], [16], [1]) + sub_out = subview(l1_out.result, [j], [16], [1]) + v_a = transfer_read( + vec_ty, sub_a, [c0], identity_map, cst0, [True] + ) + v_b = transfer_read( + vec_ty, sub_b, [c0], identity_map, cst0, [True] + ) + v_sum = arith.addf(v_a, v_b) + transfer_write( + None, v_sum, sub_out, [c0], identity_map, [True] + ) + yield_([]) + dma_memcpy_nd( + h_out, + l1_out, + dst_offsets=[offset], + dst_sizes=[tile_n], + dst_strides=[1], + ) + yield_([]) + DeallocOp(l1_a) + DeallocOp(l1_b) + DeallocOp(l1_out) + + return _build() + + +# --------------------------------------------------------------------------- +# KernelGroupSpec +# --------------------------------------------------------------------------- + +SPEC = KernelGroupSpec( + name="o_ffn", + sub_launches=( + # idx=0: O GEMM -- weight at slot 1 (wo), output at slot 2 (proj) + SubLaunchSpec("o_gemm", _build_o_gemm_standalone, {}, 1, 2), + # idx=1: Res Add -- no weight, output at slot 2 (res1[2D]) + SubLaunchSpec("res_add", _build_res_add_standalone, {}, None, 2), + # idx=2: FFN RMSNorm -- weight at slot 1 (ffn_norm_w), output at slot 2 (normed2) + SubLaunchSpec("ffn_rmsnorm", _build_rmsnorm_standalone, {}, 1, 2), + # idx=3: Gate GEMM -- weight at slot 1 (w_gate), output at slot 2 (gate) + SubLaunchSpec("gate_gemm", _build_gateup_gemm_standalone, {"n": 8192}, 1, 2), + # idx=4: Up GEMM -- weight at slot 1 (w_up), output at slot 2 (up) + SubLaunchSpec("up_gemm", _build_gateup_gemm_standalone, {"n": 8192}, 1, 2), + # idx=5: SwiGLU -- no weight, gate=slot0, up=slot1, output at slot 2 + SubLaunchSpec("swiglu", _build_swiglu_standalone, {}, None, 2), + # idx=6: Down GEMM -- weight at slot 1 (w_down), output at slot 2 (down) + SubLaunchSpec("down_gemm", _build_down_gemm_standalone, {}, 1, 2), + # idx=7: FFN Add -- no weight, A=slot0 (down), B=slot1 (res1), output at slot 2 + SubLaunchSpec("ffn_add", _build_ffn_add_standalone, {}, None, 2), + ), + merged_arg_signature=( + "attn_out", # 0 activation input + "wo", # 1 weight (static) + "proj", # 2 intermediate + "x_residual", # 3 activation input + "res1", # 4 intermediate (shared: res_add out + ffn_add B) + "ffn_norm_w", # 5 weight (static) + "normed2", # 6 intermediate + "w_gate", # 7 weight (static) + "gate", # 8 intermediate + "w_up", # 9 weight (static) + "up", # 10 intermediate + "swiglu", # 11 intermediate + "w_down", # 12 weight (static) + "down", # 13 intermediate + "output", # 14 intermediate (final 1D output) + ), + weight_slots=frozenset({1, 5, 7, 9, 12}), + intermediate_slots=frozenset({2, 4, 6, 8, 10, 11, 13, 14}), + output_slots_for_validation=(14,), + baton_links=( + # Stitch arg_map verified against o_ffn_multi.py lines 457-465: + # L1 {0:0,1:1,2:2} L2 {0:2,1:3,2:4} L3 {0:4,1:5,2:6} + # L4 {0:6,1:7,2:8} L5 {0:6,1:9,2:10} L6 {0:8,1:10,2:11} + # L7 {0:11,1:12,2:13} L8 {0:13,1:4,2:14} + BatonLink(0, 2, 1, 0), # o_gemm.proj (slot2) -> res_add.A (slot0) + BatonLink(1, 2, 2, 0), # res_add.res1 (slot2) -> ffn_rmsnorm.x (slot0) + BatonLink(2, 2, 3, 0), # ffn_rmsnorm.normed2 (slot2) -> gate_gemm.x (slot0) + BatonLink(2, 2, 4, 0), # ffn_rmsnorm.normed2 (slot2) -> up_gemm.x (slot0) + BatonLink(3, 2, 5, 0), # gate_gemm.gate (slot2) -> swiglu.gate (slot0) + BatonLink(4, 2, 5, 1), # up_gemm.up (slot2) -> swiglu.up (slot1) + BatonLink(5, 2, 6, 0), # swiglu.swiglu (slot2) -> down_gemm.x (slot0) + BatonLink(6, 2, 7, 0), # down_gemm.down (slot2) -> ffn_add.A (slot0) + BatonLink( + 1, 2, 7, 1 + ), # res_add.res1 (slot2) -> ffn_add.B (slot1) [residual-of-residual] + ), +) +``` + +- [ ] **Step 3: Verify the spec** + +```bash +cd programming_examples/llama32_1b/ablation/prefill +python3 -c " +from specs.o_ffn import SPEC +from specs.kernel_group import validate_baton_links +validate_baton_links(SPEC.sub_launches, SPEC.baton_links) +print(f'{SPEC.name}: {len(SPEC.sub_launches)} sub-launches, {len(SPEC.baton_links)} baton links') +" +``` +Expected: `o_ffn: 8 sub-launches, 9 baton links`. If any sub-builder import fails, the implementer must adjust the standalone helpers per the actual production code in `o_ffn_multi.py`. + +- [ ] **Step 4: Commit** + +```bash +git add programming_examples/llama32_1b/ablation/prefill/specs/o_ffn.py +git commit -m "ablation/prefill: concrete spec for o_ffn (8 sub-launches at seq=2048)" +``` + +--- + +## Phase 2 — Standalone Builders + Compile (Tasks 5–7) + +## Task 5: Standalone builders for `rms_gemms_rope` + +**Files:** +- Create: `programming_examples/llama32_1b/ablation/prefill/standalone_builders/rms_gemms_rope.py` + +This is a thin wrapper file. Most of the build logic lives in `specs/rms_gemms_rope.py` (the `_build_*_standalone` helpers). This file just re-exports a `STANDALONES` registry compatible with the compile harness in T7. + +- [ ] **Step 1: Write the file** + +```python +"""Single-launch standalone modules for the prefill rms_gemms_rope kernel-group. + +Exports a STANDALONES registry compatible with cells/common.py:compile_standalone_kernels. +Each entry: (name, build_fn, build_kwargs). +""" + +from specs.rms_gemms_rope import SPEC + + +STANDALONES = [ + (sub.name, sub.builder_ref, sub.build_kwargs) + for sub in SPEC.sub_launches +] +``` + +- [ ] **Step 2: Verify the registry** + +```bash +cd programming_examples/llama32_1b/ablation/prefill +python3 -c " +from standalone_builders.rms_gemms_rope import STANDALONES +assert len(STANDALONES) == 6, f'expected 6, got {len(STANDALONES)}' +for name, build_fn, kwargs in STANDALONES: + print(f'{name}: {build_fn.__name__}({kwargs})') +" +``` +Expected: 6 lines listing rmsnorm, q_gemm, k_gemm, v_gemm, rope_q, rope_k with their kwargs. + +- [ ] **Step 3: Commit** + +```bash +git add programming_examples/llama32_1b/ablation/prefill/standalone_builders/rms_gemms_rope.py +git commit -m "ablation/prefill: standalone STANDALONES registry for rms_gemms_rope" +``` + +--- + +## Task 6: Standalone builders for `o_ffn` + +**Files:** +- Create: `programming_examples/llama32_1b/ablation/prefill/standalone_builders/o_ffn.py` + +Identical pattern to T5; only the spec module differs. + +- [ ] **Step 1: Write the file** + +```python +"""Single-launch standalone modules for the prefill o_ffn kernel-group. + +Exports a STANDALONES registry compatible with cells/common.py:compile_standalone_kernels. +""" + +from specs.o_ffn import SPEC + + +STANDALONES = [ + (sub.name, sub.builder_ref, sub.build_kwargs) + for sub in SPEC.sub_launches +] +``` + +- [ ] **Step 2: Verify** + +```bash +cd programming_examples/llama32_1b/ablation/prefill +python3 -c " +from standalone_builders.o_ffn import STANDALONES +assert len(STANDALONES) == 8, f'expected 8, got {len(STANDALONES)}' +for name, build_fn, kwargs in STANDALONES: + print(f'{name}: {build_fn.__name__}({kwargs})') +" +``` +Expected: 8 lines. + +- [ ] **Step 3: Commit** + +```bash +git add programming_examples/llama32_1b/ablation/prefill/standalone_builders/o_ffn.py +git commit -m "ablation/prefill: standalone STANDALONES registry for o_ffn" +``` + +--- + +## Task 7: Compile harness — `cells/common.py` + actual compile + +**Files:** +- Create: `programming_examples/llama32_1b/ablation/prefill/cells/common.py` +- Create: `programming_examples/llama32_1b/ablation/prefill/.gitignore` + +This file mirrors Plan 1's `cells/common.py` (lifting the `_extract_public_func_name` regex, `compile_standalone_kernels`, `_share_bo`, `standalone_backend_kwargs` helpers). The only difference: the compile harness uses one of two prefill backends (RMS_GEMMS_ROPE_BACKEND or O_FFN_BACKEND) per kernel-group. + +- [ ] **Step 1: Write `cells/common.py`** + +> **Implementation note (post-execution wash-up):** `compile_standalone_kernels` must wrap +> `build_fn(**kwargs)` in a `with MLIRContext():` block; without it the MLIR module +> parse context is missing and the builder crashes. Also note that +> `programming_examples/llama32_1b/kernel_builder/external_kernels.py` was modified +> alongside this task to add an `MLIR_AIE_INSTALL_DIR` env-var fallback for worktree +> path resolution — that change is a candidate for cherry-picking back to `llama-3.2-1B-devel` +> independently of the ablation work. + +```python +"""Shared helpers for prefill ablation cells. + +Lifted (and extended for two-backend support) from Plan 1's +ablation/cells/common.py. The original Plan 1 file is read-only. + +- compile_standalone_kernels(cache, group_name, registry, backend_preset): + Compile every standalone in `registry` into `cache`, using the actual + public func name extracted from the MLIR module as instance_name. +- _extract_public_func_name(mlir_text): regex over the module string. +- _share_bo(cache, src_key, src_slot, dst_key, dst_slot): alias cached BOs + for Cell C's baton-pass. +- standalone_backend_kwargs(backend_preset, verbose): returns backend kwargs + with instance_name removed (set per-kernel by compile_standalone_kernels). +""" + +import re + +from air.ir import Context as MLIRContext + +from kernel_builder.cache import KernelCache + + +def _extract_public_func_name(mlir_text): + """Find the first non-private `func.func @` in the module text.""" + for line in mlir_text.split("\n"): + if "func.func @" in line and "private" not in line: + m = re.search(r"@(\w+)", line) + if m: + return m.group(1) + raise ValueError("no public func.func found in module") + + +def standalone_backend_kwargs(backend_preset, verbose=False): + """Backend kwargs with instance_name removed (set per-kernel by caller).""" + base = {**backend_preset, "verbose": verbose} + base.pop("instance_name", None) + return base + + +def compile_standalone_kernels( + cache: KernelCache, group_name: str, registry, backend_preset +): + """Compile every standalone in `registry` into `cache` under names + f"{group_name}__{name}". Skip any kernel already in cache.artifacts. + + Each registry entry: (name, build_fn, build_kwargs). + """ + for name, build_fn, kwargs in registry: + kernel_name = f"{group_name}__{name}" + if kernel_name in cache.artifacts: + continue + with MLIRContext(): + mlir_module = build_fn(**kwargs) + public_func = _extract_public_func_name(str(mlir_module)) + be = standalone_backend_kwargs(backend_preset, verbose=cache.verbose) + be["instance_name"] = public_func + cache.compile_and_cache(kernel_name, mlir_module, be) + cache._save_manifest() + + +def _share_bo(cache, src_key, src_slot, dst_key, dst_slot): + """Replace cached BO at (dst_key, dst_slot) with the same xrt.bo as + (src_key, src_slot). Only valid after both kernels' first call has + materialized BOs.""" + src_bos = cache._cached_bos[src_key] + dst_bos = cache._cached_bos[dst_key] + dst_bos[dst_slot] = src_bos[src_slot] + + +def main(): + """python3 -m cells.common — compile both kernel-groups' standalones.""" + from kernel_builder.backend_presets import RMS_GEMMS_ROPE_BACKEND, O_FFN_BACKEND + from standalone_builders.rms_gemms_rope import STANDALONES as RMS_STD + from standalone_builders.o_ffn import STANDALONES as O_STD + + cache = KernelCache(cache_dir="standalone_cache", verbose=True) + cache.load_manifest() + compile_standalone_kernels(cache, "rms_gemms_rope", RMS_STD, RMS_GEMMS_ROPE_BACKEND) + compile_standalone_kernels(cache, "o_ffn", O_STD, O_FFN_BACKEND) + print(f"Compiled {len(cache.artifacts)} standalone ELFs.") + + +if __name__ == "__main__": + main() +``` + +- [ ] **Step 2: Add `.gitignore`** + +```bash +echo "build/" > programming_examples/llama32_1b/ablation/prefill/.gitignore +echo "standalone_cache/" >> programming_examples/llama32_1b/ablation/prefill/.gitignore +echo "results_*.json" >> programming_examples/llama32_1b/ablation/prefill/.gitignore +echo "report_*.md" >> programming_examples/llama32_1b/ablation/prefill/.gitignore +``` + +- [ ] **Step 3: Run the compile (one-time, ~10–15 min for 14 ELFs at seq=2048)** + +```bash +cd programming_examples/llama32_1b/ablation/prefill +mkdir -p build && cd build +flock -x -w 1800 /tmp/mlir-air-npu.lock python3 -m cells.common +``` + +Expected output: 14 lines `Compiled rms_gemms_rope__: s` and `Compiled o_ffn__: s`. **NO `instance_name ... does not match` warnings** (the `_extract_public_func_name` regex prevents that — see Plan 1 T6 wash-up). + +- [ ] **Step 4: Verify the manifest** + +```bash +python3 -c " +import json +with open('standalone_cache/manifest.json') as f: + m = json.load(f) +assert len(m) == 14, f'expected 14, got {len(m)}' +for name, info in sorted(m.items()): + assert info['kernel'].startswith('main:'), f'bad kernel ref: {info[\"kernel\"]}' +print(f'manifest OK: {len(m)} entries') +" +``` +Expected: `manifest OK: 14 entries`. + +- [ ] **Step 5: Commit (source + .gitignore only; no binaries)** + +```bash +git add programming_examples/llama32_1b/ablation/prefill/cells/common.py \ + programming_examples/llama32_1b/ablation/prefill/.gitignore +git commit -m "ablation/prefill: compile harness for both kernel-groups (14 ELFs)" +``` + +--- + +## Phase 3 — Cells + Golden + Validation + FA (Tasks 8–11) + +## Task 8: Cell D — production wrapper for both kernel-groups + +**Files:** +- Create: `programming_examples/llama32_1b/ablation/prefill/cells/cell_d_merged.py` + +Two cell-D entry points (one per kernel-group). Each compiles the production merged ELF (if not cached) and provides a `run_cell_d_(cache, layer_inputs, layer_idx)` function returning the same dict shape Plan 1 used. + +- [ ] **Step 1: Write cell_d_merged.py** + +```python +"""Cell D — production: invoke the merged ELFs (rms_gemms_rope.elf with 6 +launches; o_ffn.elf with 8 launches) using the production KernelCache + +backend presets. +""" + +import os +import sys + +# Ensure llama32_1b/ is on sys.path so kernel_builder and multi_launch_builder +# are importable whether this file is run directly or imported from the +# prefill/ package root. +_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +_LLAMA_DIR = os.path.normpath(os.path.join(_THIS_DIR, "..", "..", "..")) +if _LLAMA_DIR not in sys.path: + sys.path.insert(0, _LLAMA_DIR) + +import time + +import numpy as np +from ml_dtypes import bfloat16 + +from kernel_builder.cache import KernelCache +from kernel_builder.backend_presets import RMS_GEMMS_ROPE_BACKEND, O_FFN_BACKEND +from multi_launch_builder.rms_gemms_rope_multi import build_rms_gemms_rope_module +from multi_launch_builder.o_ffn_multi import build_o_ffn_module + +CONFIG = { + "seq_len": 2048, + "emb_dim": 2048, + "kv_dim": 512, + "n_heads": 32, + "n_kv_heads": 8, + "head_dim": 64, + "hidden_dim": 8192, +} + + +def compile_cell_d_rms_gemms_rope(cache: KernelCache): + if "rms_gemms_rope" in cache.artifacts: + return + mod = build_rms_gemms_rope_module( + seq_len=CONFIG["seq_len"], emb_dim=CONFIG["emb_dim"], + kv_dim=CONFIG["kv_dim"], n_heads=CONFIG["n_heads"], + n_kv_heads=CONFIG["n_kv_heads"], head_dim=CONFIG["head_dim"], + ) + cache.compile_and_cache("rms_gemms_rope", mod, + {"verbose": cache.verbose, **RMS_GEMMS_ROPE_BACKEND}) + cache._save_manifest() + + +def compile_cell_d_o_ffn(cache: KernelCache): + if "o_ffn" in cache.artifacts: + return + mod = build_o_ffn_module( + seq_len=CONFIG["seq_len"], emb_dim=CONFIG["emb_dim"], + hidden_dim=CONFIG["hidden_dim"], + ) + cache.compile_and_cache("o_ffn", mod, + {"verbose": cache.verbose, **O_FFN_BACKEND}) + cache._save_manifest() + + +def run_cell_d_rms_gemms_rope(cache, layer_inputs, layer_idx=0): + """One rms_gemms_rope call (6 launches in one xrt.run). + layer_inputs has keys: x_in, norm_w, wq, wk, wv, lut_q, lut_k. + Returns dict with normed, q, k, v, q_roped, k_roped, _wall_s. + """ + seq = CONFIG["seq_len"]; emb = CONFIG["emb_dim"]; kv = CONFIG["kv_dim"] + args = [ + layer_inputs["x_in"], + layer_inputs["norm_w"], + np.zeros((seq, emb), dtype=bfloat16), # normed + layer_inputs["wq"], + np.zeros((seq, emb), dtype=bfloat16), # q + layer_inputs["wk"], + np.zeros((seq, kv), dtype=bfloat16), # k + layer_inputs["wv"], + np.zeros((seq, kv), dtype=bfloat16), # v + layer_inputs["lut_q"], + layer_inputs["lut_k"], + np.zeros((seq, emb), dtype=bfloat16), # q_roped + np.zeros((seq, kv), dtype=bfloat16), # k_roped + ] + t0 = time.perf_counter() + out = cache.load_and_run( + "rms_gemms_rope", RMS_GEMMS_ROPE_BACKEND, + *args, + output_indices=[2, 4, 6, 8, 11, 12], + static_input_indices={1, 3, 5, 7, 9, 10}, + intermediate_indices={2, 4, 6, 8, 11, 12}, + bo_key=f"D_rms_gemms_rope_L{layer_idx}", + ) + elapsed = time.perf_counter() - t0 + return { + "normed": out[2], "q": out[4], "k": out[6], "v": out[8], + "q_roped": out[11], "k_roped": out[12], + "_wall_s": elapsed, + } + + +def run_cell_d_o_ffn(cache, layer_inputs, layer_idx=0): + """One o_ffn call (8 launches in one xrt.run). + layer_inputs has: attn_out, wo, x_residual, ffn_norm_w, w_gate, w_up, w_down. + Returns dict with output, _wall_s. + """ + seq = CONFIG["seq_len"]; emb = CONFIG["emb_dim"]; hid = CONFIG["hidden_dim"] + n_total = seq * emb + args = [ + layer_inputs["attn_out"], # 0 + layer_inputs["wo"], # 1 + np.zeros((seq, emb), dtype=bfloat16), # 2 proj + layer_inputs["x_residual"], # 3 + np.zeros((seq, emb), dtype=bfloat16), # 4 res1 + layer_inputs["ffn_norm_w"], # 5 + np.zeros((seq, emb), dtype=bfloat16), # 6 normed2 + layer_inputs["w_gate"], # 7 + np.zeros((seq, hid), dtype=bfloat16), # 8 gate + layer_inputs["w_up"], # 9 + np.zeros((seq, hid), dtype=bfloat16), # 10 up + np.zeros((seq, hid), dtype=bfloat16), # 11 swiglu + layer_inputs["w_down"], # 12 + np.zeros((seq, emb), dtype=bfloat16), # 13 down + np.zeros(n_total, dtype=bfloat16), # 14 output (1D) + ] + t0 = time.perf_counter() + out = cache.load_and_run( + "o_ffn", O_FFN_BACKEND, + *args, + output_indices=[14], + static_input_indices={1, 5, 7, 9, 12}, + intermediate_indices={2, 4, 6, 8, 10, 11, 13, 14}, + bo_key=f"D_o_ffn_L{layer_idx}", + ) + return {"output": out[14], "_wall_s": time.perf_counter() - t0} +``` + +- [ ] **Step 2: Verify import + signature** + +```bash +cd programming_examples/llama32_1b/ablation/prefill +python3 -c " +from cells.cell_d_merged import (compile_cell_d_rms_gemms_rope, + compile_cell_d_o_ffn, + run_cell_d_rms_gemms_rope, + run_cell_d_o_ffn, CONFIG) +print('OK', CONFIG['seq_len'], CONFIG['emb_dim'], CONFIG['hidden_dim']) +" +``` +Expected: `OK 2048 2048 8192`. + +- [ ] **Step 3: Commit** + +```bash +git add programming_examples/llama32_1b/ablation/prefill/cells/cell_d_merged.py +git commit -m "ablation/prefill: Cell D wrappers for rms_gemms_rope and o_ffn merged ELFs" +``` + +--- + +## Task 9: Golden fixture generator + commit + +**Files:** +- Create: `programming_examples/llama32_1b/ablation/prefill/golden/regen_golden.py` +- Generate + commit: `golden/golden_rms_gemms_rope_prefill.npz`, `golden/golden_o_ffn_prefill.npz`, `golden/golden_meta.json` + +- [ ] **Step 1: Write `regen_golden.py`** + +```python +"""Regenerate prefill golden fixtures by running Cell D once for each kernel-group. + +Uses deterministic synthetic inputs (numpy seed=42 for layer 0). +Outputs: + golden/golden_rms_gemms_rope_prefill.npz + golden/golden_o_ffn_prefill.npz + golden/golden_meta.json +""" + +import hashlib +import json +import os +import sys + +import numpy as np +from ml_dtypes import bfloat16 + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from kernel_builder.cache import KernelCache +from cells.cell_d_merged import ( + CONFIG, + compile_cell_d_rms_gemms_rope, compile_cell_d_o_ffn, + run_cell_d_rms_gemms_rope, run_cell_d_o_ffn, +) + + +def _synthetic_layer_inputs(layer_idx, config): + """Deterministic synthetic inputs for one prefill layer (seq=2048). + + Same seeding scheme as Plan 1: seed = 42 + layer_idx. + """ + rng = np.random.default_rng(42 + layer_idx) + seq = config["seq_len"]; emb = config["emb_dim"] + kv = config["kv_dim"]; hid = config["hidden_dim"] + return { + "x_in": rng.standard_normal((seq, emb)).astype(bfloat16), + "norm_w": rng.standard_normal(emb).astype(bfloat16), + "wq": rng.standard_normal((emb, emb)).astype(bfloat16), + "wk": rng.standard_normal((emb, kv)).astype(bfloat16), + "wv": rng.standard_normal((emb, kv)).astype(bfloat16), + "lut_q": rng.standard_normal(seq * emb).astype(bfloat16), + "lut_k": rng.standard_normal(seq * kv).astype(bfloat16), + "wo": rng.standard_normal((emb, emb)).astype(bfloat16), + "ffn_norm_w": rng.standard_normal(emb).astype(bfloat16), + "w_gate": rng.standard_normal((emb, hid)).astype(bfloat16), + "w_up": rng.standard_normal((emb, hid)).astype(bfloat16), + "w_down": rng.standard_normal((hid, emb)).astype(bfloat16), + } + + +def main(): + cache = KernelCache(cache_dir="standalone_cache", verbose=True) + cache.load_manifest() + compile_cell_d_rms_gemms_rope(cache) + compile_cell_d_o_ffn(cache) + + inputs = _synthetic_layer_inputs(0, CONFIG) + + # rms_gemms_rope golden + rg_inputs = {k: inputs[k] for k in ["x_in","norm_w","wq","wk","wv","lut_q","lut_k"]} + rg_out = run_cell_d_rms_gemms_rope(cache, rg_inputs, layer_idx=0) + rg_path = os.path.join(os.path.dirname(__file__), "golden_rms_gemms_rope_prefill.npz") + np.savez(rg_path, **{k: v for k, v in rg_out.items() if not k.startswith("_")}) + + # For o_ffn golden, attn_out comes from FA in production. For the golden + # we use a CPU FA reference computed from rg_out's q_roped/k_roped/v — + # since FA is invariant across cells, all cells will see the same attn_out. + # Simplest: synthesize attn_out from the same RNG (it is what flows into + # o_ffn's slot 0 in every cell; the bytes are determined upstream). + attn_out = np.random.default_rng(42 + 0 + 1000).standard_normal( + (CONFIG["seq_len"], CONFIG["emb_dim"])).astype(bfloat16) + of_inputs = { + "attn_out": attn_out, + "wo": inputs["wo"], + "x_residual": inputs["x_in"], # the residual is the layer input + "ffn_norm_w": inputs["ffn_norm_w"], + "w_gate": inputs["w_gate"], + "w_up": inputs["w_up"], + "w_down": inputs["w_down"], + } + of_out = run_cell_d_o_ffn(cache, of_inputs, layer_idx=0) + of_path = os.path.join(os.path.dirname(__file__), "golden_o_ffn_prefill.npz") + np.savez(of_path, **{k: v for k, v in of_out.items() if not k.startswith("_")}) + + meta = { + "config": CONFIG, + "rms_gemms_rope": { + "input_hashes": {k: hashlib.sha256(v.tobytes()).hexdigest()[:16] + for k, v in rg_inputs.items()}, + "output_hashes": {k: hashlib.sha256(v.tobytes()).hexdigest()[:16] + for k, v in rg_out.items() if not k.startswith("_")}, + }, + "o_ffn": { + "input_hashes": {k: hashlib.sha256(v.tobytes()).hexdigest()[:16] + for k, v in of_inputs.items()}, + "output_hashes": {k: hashlib.sha256(v.tobytes()).hexdigest()[:16] + for k, v in of_out.items() if not k.startswith("_")}, + }, + } + with open(os.path.join(os.path.dirname(__file__), "golden_meta.json"), "w") as f: + json.dump(meta, f, indent=2) + print(f"Wrote {rg_path}, {of_path}, golden_meta.json") + + +if __name__ == "__main__": + main() +``` + +- [ ] **Step 2: Run the generator** + +```bash +cd programming_examples/llama32_1b/ablation/prefill/build +flock -x -w 1800 /tmp/mlir-air-npu.lock python3 ../golden/regen_golden.py +``` + +Expected: 2 ELFs compiled (rms_gemms_rope ~30s, o_ffn ~50s if not cached), then `Wrote .../golden_rms_gemms_rope_prefill.npz, .../golden_o_ffn_prefill.npz, golden_meta.json`. The two npz files together should be a few MB (six 2048×N arrays + one 2048×2048 output = ~16-32 MB total). + +- [ ] **Step 3: Verify fixtures** + +```bash +ls -la programming_examples/llama32_1b/ablation/prefill/golden/ +python3 -c " +import numpy as np +rg = np.load('programming_examples/llama32_1b/ablation/prefill/golden/golden_rms_gemms_rope_prefill.npz') +of = np.load('programming_examples/llama32_1b/ablation/prefill/golden/golden_o_ffn_prefill.npz') +print('rg files:', list(rg.files)) +print('of files:', list(of.files)) +" +``` +Expected: rg has 6 arrays (normed, q, k, v, q_roped, k_roped); of has 1 array (output). + +- [ ] **Step 4: Commit fixtures** + +```bash +git add programming_examples/llama32_1b/ablation/prefill/golden/ +git commit -m "ablation/prefill: golden fixtures from Cell D for rms_gemms_rope and o_ffn" +``` + +--- + +## Task 10: Validation gate (reuse Plan 1 + new test) + +**Files:** +- Create: `programming_examples/llama32_1b/ablation/prefill/tests/test_validation_gate.py` + +We **reuse Plan 1's `validate.py` verbatim** (no copy). Plan 1's `validate_against_golden(cell_outputs, golden_dir)` reads from `/golden_rms_gemv_rope.npz` though — so we either pass a different filename or accept Plan 1's logic. + +The simplest: lift the validate logic into a small `prefill/validate.py` that takes a `golden_npz_filename` parameter so we can reuse it for both kernel-groups' goldens. + +- [ ] **Step 1: Create `prefill/validate.py` (lifted from Plan 1, parameterized)** + +```python +"""Per-cell validation — parameterized version of Plan 1's validate.py. + +Plan 1's validate.py hardcodes the golden filename to +"golden_rms_gemv_rope.npz". For prefill we have two goldens, so we +parameterize the filename. The byte-equality contract is identical. +""" + +import os + +import numpy as np + +# Reuse the exception class from Plan 1 if available; redefine if not. +try: + from validate import GoldenMismatch # Plan 1's exception +except ImportError: + class GoldenMismatch(AssertionError): + pass + + +def validate_against_golden(cell_outputs: dict, golden_dir: str, npz_filename: str): + """Compare every key in cell_outputs to the matching array in + /. Raise GoldenMismatch on any diff.""" + npz = np.load(os.path.join(golden_dir, npz_filename)) + for key in npz.files: + if key not in cell_outputs: + raise GoldenMismatch(f"cell missing output '{key}'") + gv = npz[key] + cv = cell_outputs[key] + if cv.shape != gv.shape: + raise GoldenMismatch(f"{key}: shape mismatch cell={cv.shape} golden={gv.shape}") + if cv.dtype.itemsize != gv.dtype.itemsize: + raise GoldenMismatch(f"{key}: itemsize mismatch") + if cv.tobytes() != gv.tobytes(): + from ml_dtypes import bfloat16 as _bf16 + cf = cv.view(np.uint8).view(_bf16).astype(np.float32) if cv.dtype != np.float32 else cv + gf = gv.view(np.uint8).view(_bf16).astype(np.float32) if gv.dtype != np.float32 else gv + max_abs = float(np.max(np.abs(cf - gf))) + max_rel = float(np.max(np.abs((cf - gf) / (np.abs(gf) + 1e-9)))) + raise GoldenMismatch(f"{key}: byte mismatch max_abs={max_abs:.4g} max_rel={max_rel:.4g}") +``` + +- [ ] **Step 2: Write the test** + +`prefill/tests/test_validation_gate.py`: + +```python +"""Test the prefill validation gate against the committed goldens.""" + +import os + +import numpy as np +import pytest +from ml_dtypes import bfloat16 + +from validate import validate_against_golden, GoldenMismatch + +GOLDEN_DIR = os.path.join(os.path.dirname(__file__), "..", "golden") + + +def _load(filename): + npz = np.load(os.path.join(GOLDEN_DIR, filename)) + return {k: npz[k] for k in npz.files} + + +def test_rms_gemms_rope_passes_on_exact_match(): + g = _load("golden_rms_gemms_rope_prefill.npz") + validate_against_golden(g, GOLDEN_DIR, "golden_rms_gemms_rope_prefill.npz") + + +def test_rms_gemms_rope_raises_on_byte_diff(): + g = _load("golden_rms_gemms_rope_prefill.npz") + perturbed = {k: v.copy() for k, v in g.items()} + arr = perturbed["normed"].view(np.uint8).copy() + arr[0] ^= 0x01 + perturbed["normed"] = arr.view(bfloat16).reshape(g["normed"].shape) + with pytest.raises(GoldenMismatch, match="normed"): + validate_against_golden(perturbed, GOLDEN_DIR, "golden_rms_gemms_rope_prefill.npz") + + +def test_o_ffn_passes_on_exact_match(): + g = _load("golden_o_ffn_prefill.npz") + validate_against_golden(g, GOLDEN_DIR, "golden_o_ffn_prefill.npz") + + +def test_o_ffn_raises_on_byte_diff(): + g = _load("golden_o_ffn_prefill.npz") + perturbed = {k: v.copy() for k, v in g.items()} + arr = perturbed["output"].view(np.uint8).copy() + arr[0] ^= 0x01 + perturbed["output"] = arr.view(bfloat16).reshape(g["output"].shape) + with pytest.raises(GoldenMismatch, match="output"): + validate_against_golden(perturbed, GOLDEN_DIR, "golden_o_ffn_prefill.npz") +``` + +- [ ] **Step 3: Run the tests** + +```bash +cd programming_examples/llama32_1b/ablation/prefill && python3 -m pytest tests/test_validation_gate.py -v +``` +Expected: 4 passed. + +- [ ] **Step 4: Commit** + +```bash +git add programming_examples/llama32_1b/ablation/prefill/validate.py \ + programming_examples/llama32_1b/ablation/prefill/tests/test_validation_gate.py +git commit -m "ablation/prefill: parameterized validation gate + tests" +``` + +--- + +## Task 11: FA invariant integration + +**Files:** +- Create: `programming_examples/llama32_1b/ablation/prefill/cells/flash_attn_const.py` + +FA's role per spec: held constant in every cell. Same standalone ELF, same invocation pattern, same BO management. The only thing the cells do differently around FA is the upstream/downstream BO management of rms_gemms_rope's outputs and o_ffn's inputs — both happen via host hop in every cell (matches production). + +- [ ] **Step 1: Write `flash_attn_const.py`** + +```python +"""FlashAttention invariant: same standalone ELF + same invocation in every cell. + +FA's MLIR builder is at programming_examples/flash_attention/kernel_fusion_based/attn_npu2_seqfirst.py +with kwargs matching Plan 1's compile_all_kernels() in llama32_1b_prefill.py. +""" + +import time + +import numpy as np +from ml_dtypes import bfloat16 + +from kernel_builder.cache import KernelCache + + +def _attn_backend_kwargs(): + return { + "verbose": False, + "omit_while_true_loop": False, # head_dim=64, lkp=64 enables shared buffers + "omit_pingpong": "all", + "runtime_loop_tiling_sizes": [1, 1], + "output_format": "elf", + "instance_name": "attention_bf16", + } + + +def compile_flash_attn(cache: KernelCache, config): + """Compile FA ELF if not already cached. ~46s first time per profile.md.""" + if "flash_attn" in cache.artifacts: + return + from flash_attention.kernel_fusion_based.attn_npu2_seqfirst import ( + build_module as build_attn, + ) + seq = config["seq_len"]; head_dim = config["head_dim"] + n_heads = config["n_heads"]; n_kv_heads = config["n_kv_heads"] + mod = build_attn( + lk=seq, lkp=head_dim, lq=seq, lqp=256, + dk=head_dim, dv=head_dim, + num_q_tiles=4, num_cascade_stages=4, + num_heads=n_heads, num_kv_heads=n_kv_heads, + causal=True, + ) + cache.compile_and_cache("flash_attn", mod, _attn_backend_kwargs()) + cache._save_manifest() + + +def run_flash_attn(cache, q_roped, k_roped, v, layer_idx=0): + """Run FA on extracted q_roped/k_roped/v from rms_gemms_rope. + Returns attn_out (extracted to host) ready to feed o_ffn. + """ + seq = q_roped.shape[0]; emb = q_roped.shape[1] + args = [q_roped, k_roped, v, np.zeros((seq, emb), dtype=bfloat16)] + t0 = time.perf_counter() + out = cache.load_and_run( + "flash_attn", _attn_backend_kwargs(), + *args, + output_indices=[3], + intermediate_indices={3}, + bo_key=f"FA_L{layer_idx}", + ) + return {"attn_out": out[3], "_wall_s": time.perf_counter() - t0} +``` + +- [ ] **Step 2: Smoke test (compile + invoke once)** + +```bash +cd programming_examples/llama32_1b/ablation/prefill/build +flock -x -w 1800 /tmp/mlir-air-npu.lock python3 -c " +import sys, os +sys.path[:0] = ['..', '../..', '../../..'] +import numpy as np +from ml_dtypes import bfloat16 +from kernel_builder.cache import KernelCache +from cells.cell_d_merged import CONFIG +from cells.flash_attn_const import compile_flash_attn, run_flash_attn + +cache = KernelCache(cache_dir='standalone_cache', verbose=False) +cache.load_manifest() +compile_flash_attn(cache, CONFIG) +seq = CONFIG['seq_len']; emb = CONFIG['emb_dim']; kv = CONFIG['kv_dim'] +q = np.zeros((seq, emb), dtype=bfloat16) +k = np.zeros((seq, kv), dtype=bfloat16) +v = np.zeros((seq, kv), dtype=bfloat16) +out = run_flash_attn(cache, q, k, v) +print(f'FA OK, attn_out shape={out[\"attn_out\"].shape}, wall={out[\"_wall_s\"]*1000:.1f}ms') +" +``` +Expected: `FA OK, attn_out shape=(2048, 2048), wall=...ms`. First run includes ~46s compile. + +- [ ] **Step 3: Commit** + +```bash +git add programming_examples/llama32_1b/ablation/prefill/cells/flash_attn_const.py +git commit -m "ablation/prefill: FA invariant integration (compile + invoke same ELF in every cell)" +``` + +--- + +## Phase 4 — Parameterized Cells (Tasks 12–14) + +## Task 12: Cell A — naive parameterized + +**Files:** +- Create: `programming_examples/llama32_1b/ablation/prefill/cells/cell_a_naive.py` + +The cell takes a `KernelGroupSpec` and walks its `sub_launches` in order, invoking each via `cache.load_and_run(naive=True)`. Between sub-launches, the previous output is extracted to host (because naive=True forces all-read) and re-written into the next call's input array slot. + +The trick: each sub-launch's standalone signature has a fixed shape `(input_or_weight, activation_input, output)` for the GEMM/RoPE families. The activation input slot may be 0 or 1 depending on the builder. The spec's `BatonLink.consumer_in_slot` tells us which slot to write the upstream output into. For Cell A (no actual sharing), we use the baton_links list only to know how to thread Python data — not for BO aliasing. + +- [ ] **Step 1: Write `cell_a_naive.py`** + +```python +"""Cell A — Naive no-merge for a generic KernelGroupSpec. + +For each sub-launch: + 1. Allocate a numpy buffer for the output (zeros). + 2. Build the call's input arrays per the spec's BatonLink upstream + (or layer_inputs[name] if no upstream link for that input slot). + 3. Invoke cache.load_and_run with naive=True (writes everything, + reads everything every call). + 4. Stash the output into a results dict keyed by sub_launch.name. +""" + +import time + +import numpy as np +from ml_dtypes import bfloat16 + +from kernel_builder.cache import KernelCache +from cells.common import compile_standalone_kernels + + +def _consumer_input_for(spec, consumer_idx, consumer_slot, results, layer_inputs): + """Return the numpy array to put in (consumer_idx, consumer_slot). + + If a BatonLink targets this (consumer_idx, consumer_slot), use the + producer's output from results. Otherwise, look up by sub-launch name + in layer_inputs. + """ + for link in spec.baton_links: + if link.consumer_idx == consumer_idx and link.consumer_in_slot == consumer_slot: + producer_name = spec.sub_launches[link.producer_idx].name + return results[producer_name] + # Not a baton-driven slot — must be in layer_inputs by sub-launch name + sub = spec.sub_launches[consumer_idx] + # Convention: layer_inputs uses canonical slot-0 names per sub-launch. + # The implementer should adjust this lookup if the spec uses different keys. + return layer_inputs.get(f"{sub.name}_in{consumer_slot}", + layer_inputs.get(f"{sub.name}_x")) + + +def compile_cell_a(cache, spec, backend_preset): + """Compile the standalone ELFs for this kernel-group.""" + registry = [(s.name, s.builder_ref, s.build_kwargs) for s in spec.sub_launches] + compile_standalone_kernels(cache, spec.name, registry, backend_preset) + + +def run_cell_a(cache, spec, layer_inputs, layer_idx=0): + """Run all spec.sub_launches sequentially with naive=True. + + layer_inputs is a dict whose keys are documented per-spec (typically: + raw layer inputs like x_in, weight matrices, LUTs). + Returns dict with each sub-launch's output keyed by sub.name, plus _wall_s. + """ + backend = {**__import__("kernel_builder.backend_presets", fromlist=[spec.name.upper() + "_BACKEND"]).__dict__.get(spec.name.upper() + "_BACKEND", {})} + backend.pop("instance_name", None) + + results = {} + t0 = time.perf_counter() + + for idx, sub in enumerate(spec.sub_launches): + # Allocate output buffer with the right shape + # The implementer will need a per-spec shape registry to map + # (sub.name, slot) → shape. For now, we infer from layer_inputs. + # NOTE: This is a placeholder; the concrete shape lookup belongs in + # the spec or in a small helper invoked here. + out_buf = layer_inputs[f"_out_buf_{sub.name}"] # implementer provides + + # Build the call args list of length 3 (assume 3-arg standalone) + args = [None, None, None] + for slot in range(3): + if slot == sub.output_slot_in_standalone: + args[slot] = out_buf + elif slot == sub.weight_slot_in_standalone: + args[slot] = layer_inputs[f"{sub.name}_w"] + else: + # Activation input + args[slot] = _consumer_input_for(spec, idx, slot, results, layer_inputs) + + result = cache.load_and_run( + f"{spec.name}__{sub.name}", backend, + *args, + output_indices=[sub.output_slot_in_standalone], + naive=True, + ) + results[sub.name] = result[sub.output_slot_in_standalone] + + elapsed = time.perf_counter() - t0 + results["_wall_s"] = elapsed + return results +``` + +**Note on `_out_buf_` and `_w`**: the implementer should refine `layer_inputs`'s schema. A cleaner approach is to add a small `_shape_map` or `_naming_convention` field to `KernelGroupSpec` so cells can compute output buffer sizes and look up weights/activations by their sub-launch slot positions deterministically. + +The above is a starting point — the implementer is expected to iterate on the helper functions as they discover the actual weight/input shapes per sub-launch. The contract is: `run_cell_a(cache, spec, layer_inputs)` returns `{sub.name: output_array, ..., "_wall_s": float}` for every sub.name in `spec.sub_launches`. + +- [ ] **Step 2: Sanity-check single-layer for rms_gemms_rope vs golden** + +```bash +cd programming_examples/llama32_1b/ablation/prefill/build +flock -x -w 1800 /tmp/mlir-air-npu.lock python3 -c " +import sys, os +sys.path[:0] = ['..', '../..', '../../..'] +import numpy as np +from kernel_builder.cache import KernelCache +from kernel_builder.backend_presets import RMS_GEMMS_ROPE_BACKEND +from cells.cell_a_naive import compile_cell_a, run_cell_a +from specs.rms_gemms_rope import SPEC +from golden.regen_golden import _synthetic_layer_inputs, CONFIG +from validate import validate_against_golden, GoldenMismatch + +cache = KernelCache(cache_dir='standalone_cache', verbose=False) +cache.load_manifest() +compile_cell_a(cache, SPEC, RMS_GEMMS_ROPE_BACKEND) + +layer_inputs = _synthetic_layer_inputs(0, CONFIG) +# Adapter: convert layer_inputs into the schema cell_a_naive expects +# (this is the implementer's first iteration job — write the adapter) +# ... +out = run_cell_a(cache, SPEC, layer_inputs) +# Map cell-A's per-sub-launch outputs to the golden's keys +cell_outputs = { + 'normed': out['rmsnorm'], + 'q': out['q_gemm'], + 'k': out['k_gemm'], + 'v': out['v_gemm'], + 'q_roped': out['rope_q'], + 'k_roped': out['rope_k'], +} +try: + validate_against_golden(cell_outputs, '../golden', 'golden_rms_gemms_rope_prefill.npz') + print('Cell A rms_gemms_rope bit-exact PASS') +except GoldenMismatch as e: + print(f'Cell A rms_gemms_rope FAIL: {e}') +" +``` + +If the script errors due to schema gaps (`_out_buf_` keys missing), iterate on `_consumer_input_for` and the layer_inputs adapter until validation passes. **Do not push through with non-bit-exact results.** + +If you cannot get bit-exact PASS within reasonable effort, escalate as BLOCKED — the parameterization may need a richer spec (e.g., shape map per sub-launch) or the slot conventions may be off. + +- [ ] **Step 3: Commit only after PASS for both kernel-groups** + +```bash +git add programming_examples/llama32_1b/ablation/prefill/cells/cell_a_naive.py +git commit -m "ablation/prefill: Cell A naive parameterized harness" +``` + +--- + +## Task 13: Cell B — static parameterized + +**Files:** +- Create: `programming_examples/llama32_1b/ablation/prefill/cells/cell_b_static.py` + +Identical structure to Cell A, but adds a `preload_cell_b(cache, spec, weights_per_layer)` that writes weights once per layer with `static_input_indices={spec.weight_slots}` and matching `bo_key`. The run path uses `static_input_indices` to skip the rewrite. + +- [ ] **Step 1: Write `cell_b_static.py`** + +Mirror Plan 1's `cells/cell_b_static.py` pattern (reference: `programming_examples/llama32_1b/ablation/cells/cell_b_static.py:1-179`), but replace the hardcoded sub-launch loop with a walk over `spec.sub_launches`. + +For each sub-launch, the preload does: + +```python +cache.load_and_run( + f"{spec.name}__{sub.name}", backend, + *_preload_args(sub, weights_per_layer[li]), + output_indices=[sub.output_slot_in_standalone], + static_input_indices={sub.weight_slot_in_standalone} + if sub.weight_slot_in_standalone is not None else set(), + bo_key=f"B_{spec.name}_{sub.name}_L{li}", +) +``` + +The actual run path is the same dataflow as Cell A but with: +- No `naive=True` flag. +- `static_input_indices={sub.weight_slot_in_standalone}` set per call. +- Same `bo_key` as preload. + +Skip showing the full file — the implementer can copy Cell A's structure and add the static_input_indices argument. The bit-exact validation step is identical to Cell A's Step 2. + +- [ ] **Step 2: Validate bit-exact for both kernel-groups** + +Same one-liner pattern as Task 12 Step 2, importing `cell_b_static`. Expected: `Cell B rms_gemms_rope bit-exact PASS` AND `Cell B o_ffn bit-exact PASS`. + +- [ ] **Step 3: Commit on success** + +```bash +git add programming_examples/llama32_1b/ablation/prefill/cells/cell_b_static.py +git commit -m "ablation/prefill: Cell B per-layer weight BOs parameterized" +``` + +--- + +## Task 14: Cell C — charitable parameterized (BO aliasing) + +**Files:** +- Create: `programming_examples/llama32_1b/ablation/prefill/cells/cell_c_charitable.py` + +Cell C extends Cell B by aliasing intermediate BOs across separate `xrt.run()` calls per `spec.baton_links`. The pattern from Plan 1 (`programming_examples/llama32_1b/ablation/cells/cell_c_charitable.py:1-223`) generalizes cleanly: walk `spec.baton_links` and call `_share_bo` from `cells/common.py`. + +- [ ] **Step 1: Write `cell_c_charitable.py`** + +The structure: + +```python +def preload_cell_c(cache, spec, weights_per_layer, backend_preset): + """Same allocation as Cell B (one call per kernel per layer with weights), + then walk spec.baton_links and alias intermediate BOs.""" + # ... Cell B preload pattern ... + for li in range(len(weights_per_layer)): + for link in spec.baton_links: + producer = spec.sub_launches[link.producer_idx] + consumer = spec.sub_launches[link.consumer_idx] + _share_bo( + cache, + f"C_{spec.name}_{producer.name}_L{li}", link.producer_out_slot, + f"C_{spec.name}_{consumer.name}_L{li}", link.consumer_in_slot, + ) + + +def run_cell_c(cache, spec, layer_inputs, layer_idx=0): + """Same call sequence as Cell B but with intermediate_indices set on + aliased slots so the host doesn't write zero-fill to them.""" + # For each call, intermediate_indices includes: + # - The output slot if it's a producer in any baton_link + # - Any input slot if this call is the consumer of a baton_link + # Build per-sub-launch intermediate sets from the spec.baton_links. + intermediate_for = {} # sub_idx -> set of slots + for link in spec.baton_links: + intermediate_for.setdefault(link.producer_idx, set()).add(link.producer_out_slot) + intermediate_for.setdefault(link.consumer_idx, set()).add(link.consumer_in_slot) + # ... rest mirrors Cell B with intermediate_indices=intermediate_for[idx] ... +``` + +The implementer should reference Plan 1's `cell_c_charitable.py` for the per-call boilerplate (allocating BO via load_and_run with dummy data first, then aliasing, then the actual timed run with `intermediate_indices`). + +- [ ] **Step 2: Validate bit-exact for both kernel-groups** + +Same pattern as Tasks 12/13. Expected: `Cell C rms_gemms_rope bit-exact PASS` AND `Cell C o_ffn bit-exact PASS`. + +If aliasing fails, debug per Plan 1's notes (Task 13 in the decode pilot plan): `print(id(...))` to verify the BOs are the same object after `_share_bo`. + +- [ ] **Step 3: Commit on success** + +```bash +git add programming_examples/llama32_1b/ablation/prefill/cells/cell_c_charitable.py +git commit -m "ablation/prefill: Cell C BO baton-pass parameterized" +``` + +--- + +## Phase 5 — Multi-Layer + Orchestrator (Tasks 15–16) + +## Task 15: Multi-layer wrapper + +**Files:** +- Create: `programming_examples/llama32_1b/ablation/prefill/cells/multi_layer.py` + +Wraps a per-layer triple in a 16-layer loop. The `x_in` of layer L+1 = `output` of layer L's o_ffn. FA runs between rms_gemms_rope and o_ffn in every layer, with `attn_out` extracted to host and fed into o_ffn's slot 0. + +- [ ] **Step 1: Write `multi_layer.py`** + +```python +"""16-layer prefill wrapper. + +Threads: rms_gemms_rope[L] -> FA[L] -> o_ffn[L] -> rms_gemms_rope[L+1] + +The cell-A/B/C/D dispatch strategy is independent of this wrapper; we +take the cell's per-kernel-group runner as a parameter. +""" + +import time + +import numpy as np +from ml_dtypes import bfloat16 + +from cells.flash_attn_const import run_flash_attn + + +def run_16_layer_prefill( + cache, config, + run_rms_gemms_rope, run_o_ffn, + layer_inputs_per_layer, +): + """Run a 16-layer prefill via the supplied per-kernel-group runners. + + Args: + cache: shared KernelCache (FA + both groups + standalones all reside here) + config: dict from cell_d_merged.CONFIG + run_rms_gemms_rope(cache, layer_inputs, layer_idx) -> {normed,q,k,v,q_roped,k_roped, _wall_s} + run_o_ffn(cache, layer_inputs, layer_idx) -> {output, _wall_s} + layer_inputs_per_layer: list of 16 dicts, each with all per-layer weights+LUTs+x_in[layer 0 only] + + Returns dict with: + per_layer_wall: list of 16 floats (wall time per layer including FA) + total_wall: float + final_output: numpy array (last layer's o_ffn output) + """ + n_layers = len(layer_inputs_per_layer) + per_layer_wall = [] + x_in = layer_inputs_per_layer[0]["x_in"] + final_output = None + + t_total_start = time.perf_counter() + for L in range(n_layers): + layer_in = dict(layer_inputs_per_layer[L]) + layer_in["x_in"] = x_in # threaded from previous layer + + t_layer_start = time.perf_counter() + + # 1. rms_gemms_rope + rg_out = run_rms_gemms_rope(cache, layer_in, layer_idx=L) + # 2. FA (invariant) + # rms_gemms_rope returns 1D flat arrays; FA expects 2D (seq, dim) + seq = config["seq_len"] + emb = config["emb_dim"] + kv = config["kv_dim"] + q_roped_2d = rg_out["q_roped"].reshape(seq, emb) + k_roped_2d = rg_out["k_roped"].reshape(seq, kv) + v_2d = rg_out["v"].reshape(seq, kv) + fa_out = run_flash_attn(cache, q_roped_2d, k_roped_2d, v_2d, layer_idx=L) + # 3. o_ffn — assemble inputs + of_in = { + "attn_out": fa_out["attn_out"], + "wo": layer_in["wo"], + "x_residual": x_in, + "ffn_norm_w": layer_in["ffn_norm_w"], + "w_gate": layer_in["w_gate"], + "w_up": layer_in["w_up"], + "w_down": layer_in["w_down"], + } + of_out = run_o_ffn(cache, of_in, layer_idx=L) + # The o_ffn output (slot 14) is 1D (n_total = seq*emb); reshape for next layer + x_in = of_out["output"].reshape(config["seq_len"], config["emb_dim"]) + final_output = x_in + + per_layer_wall.append(time.perf_counter() - t_layer_start) + + total_wall = time.perf_counter() - t_total_start + return { + "per_layer_wall": per_layer_wall, + "total_wall": total_wall, + "final_output": final_output, + } +``` + +- [ ] **Step 2: Smoke test (Cell D × 2 layers as a sanity check, not 16)** + +```bash +cd programming_examples/llama32_1b/ablation/prefill/build +flock -x -w 1800 /tmp/mlir-air-npu.lock python3 -c " +import sys, os +sys.path[:0] = ['..', '../..', '../../..'] +from kernel_builder.cache import KernelCache +from cells.cell_d_merged import (CONFIG, compile_cell_d_rms_gemms_rope, + compile_cell_d_o_ffn, + run_cell_d_rms_gemms_rope, run_cell_d_o_ffn) +from cells.flash_attn_const import compile_flash_attn +from cells.multi_layer import run_16_layer_prefill +from golden.regen_golden import _synthetic_layer_inputs + +cache = KernelCache(cache_dir='standalone_cache', verbose=False) +cache.load_manifest() +compile_cell_d_rms_gemms_rope(cache) +compile_cell_d_o_ffn(cache) +compile_flash_attn(cache, CONFIG) + +layers = [_synthetic_layer_inputs(L, CONFIG) for L in range(2)] +out = run_16_layer_prefill(cache, CONFIG, + run_cell_d_rms_gemms_rope, run_cell_d_o_ffn, layers) +print(f'2-layer Cell D: total={out[\"total_wall\"]*1000:.1f}ms, ' + f'per_layer={[f\"{w*1000:.1f}\" for w in out[\"per_layer_wall\"]]}') +" +``` + +Expected: a number around 160 ms (= 2 layers × ~80 ms/layer per profile.md). If much higher, check for kernel re-compile happening per layer (shouldn't — the artifact cache should hit on second call). + +- [ ] **Step 3: Commit** + +```bash +git add programming_examples/llama32_1b/ablation/prefill/cells/multi_layer.py +git commit -m "ablation/prefill: 16-layer wrapper threading rms_gemms_rope -> FA -> o_ffn" +``` + +--- + +## Task 16: `run_ablation.py` orchestrator + +**Files:** +- Create: `programming_examples/llama32_1b/ablation/prefill/run_ablation.py` + +Three modes: `--scope=single-layer`, `--scope=16-layer`, `--scope=both` (default). For each scope, run validation gate first (single-layer Cell A/B/C/D each validated against golden), then time each cell over N trials. + +- [ ] **Step 1: Write the orchestrator** + +> **Implementation note (post-execution wash-up):** Two fixes were applied versus the +> original skeleton: +> 1. **sys.path always-remove-then-insert:** `_PREFILL` must be at `sys.path[0]` so +> `prefill/cells/` wins over any `ablation/cells/`. The pattern is: append lower-priority +> dirs, then force `_PREFILL` to index 0 with remove-then-insert. +> 2. **`_unload_all_contexts()` between cells in 16-layer scope:** The NPU has ~16 HW +> context slots. Cell A/B/C each load 14 standalone contexts + FA = 15 total, plus +> Cell D adds 2 merged + FA = 3. Without unloading between cells the limit is exceeded. +> `_unload_all_contexts` clears `cache._loaded` and `cache._cached_bos`; Cell B/C +> weights are then re-preloaded before the 16-layer run. + +```python +"""Run the prefill 4-cell ablation. + +Modes: + --scope=single-layer 5 trials × 1-layer cell call (per kernel-group) + --scope=16-layer 5 trials × 16-layer triple (rms->FA->o_ffn) loop + --scope=both (default) both above + +Run from programming_examples/llama32_1b/ablation/prefill/build/ +(where standalone_cache/ lives and xclbins are found). +""" + +import argparse +import json +import os +import sys +import time + +# Path setup: this script lives in prefill/; CWD is build/ (where standalone_cache/ lives) +# prefill/ -> ablation/ -> llama32_1b/ -> programming_examples/ +_PREFILL = os.path.dirname(os.path.abspath(__file__)) +_ABLATION = os.path.dirname(_PREFILL) +_LLAMA = os.path.dirname(_ABLATION) +_PROG_EXAMPLES = os.path.dirname(_LLAMA) + +# Insert in ascending priority: _PROG_EXAMPLES appended, _PREFILL at front. +# Use append for lower-priority dirs so they don't shadow prefill's 'cells' package. +for p in (_PROG_EXAMPLES, _LLAMA, _ABLATION): + if p not in sys.path: + sys.path.append(p) +# _PREFILL must be at index 0 so prefill/cells/ wins over ablation/cells/. +if _PREFILL in sys.path: + sys.path.remove(_PREFILL) +sys.path.insert(0, _PREFILL) + +import numpy as np +from ml_dtypes import bfloat16 + +from kernel_builder.cache import KernelCache +from kernel_builder.backend_presets import RMS_GEMMS_ROPE_BACKEND, O_FFN_BACKEND + +from validate import validate_against_golden, GoldenMismatch +from cells import cell_a_naive, cell_b_static, cell_c_charitable, cell_d_merged +from cells.flash_attn_const import compile_flash_attn +from cells.multi_layer import run_16_layer_prefill +from specs.rms_gemms_rope import SPEC as RG_SPEC +from specs.o_ffn import SPEC as OF_SPEC +from golden.regen_golden import _synthetic_layer_inputs + +GOLDEN_DIR = os.path.join(_PREFILL, "golden") + + +# --------------------------------------------------------------------------- +# Context management +# --------------------------------------------------------------------------- + + +def _unload_all_contexts(cache): + """Unload all XRT HW contexts and drop all cached BOs. + + The NPU has a limited number of HW context slots (~16). When switching + between single-layer (14+ standalone contexts) and 16-layer (up to 15 + contexts for Cell A/B/C), we must release all contexts first to avoid + hitting the limit. + + BOs are allocated against a specific XRT device handle; after unloading + the backend that handle is nulled, so the old BO objects are unusable. + We must also clear _cached_bos so the next load_and_run allocates fresh + BOs against the new device. This means preloaded Cell B/C weights are + lost and will be re-written on the next call (acceptable since the + 16-layer loop only runs one cell at a time anyway). + """ + for name, (backend, _) in list(cache._loaded.items()): + try: + backend.unload() + except Exception: + pass + cache._loaded.clear() + cache._cached_bos.clear() + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--trials", type=int, default=5) + ap.add_argument( + "--scope", + choices=["single-layer", "16-layer", "both"], + default="both", + ) + ap.add_argument("--out", default=None) + args = ap.parse_args() + + cache = KernelCache(cache_dir="standalone_cache", verbose=False) + cache.load_manifest() + + # ---- Compile all cells + FA (idempotent -- skips if already cached) ---- + print("=== Compiling kernels (idempotent) ===") + cell_a_naive.compile_cell_a(cache, RG_SPEC, RMS_GEMMS_ROPE_BACKEND) + cell_a_naive.compile_cell_a(cache, OF_SPEC, O_FFN_BACKEND) + cell_b_static.compile_cell_b(cache, RG_SPEC, RMS_GEMMS_ROPE_BACKEND) + cell_b_static.compile_cell_b(cache, OF_SPEC, O_FFN_BACKEND) + cell_c_charitable.compile_cell_c(cache, RG_SPEC, RMS_GEMMS_ROPE_BACKEND) + cell_c_charitable.compile_cell_c(cache, OF_SPEC, O_FFN_BACKEND) + cell_d_merged.compile_cell_d_rms_gemms_rope(cache) + cell_d_merged.compile_cell_d_o_ffn(cache) + compile_flash_attn(cache, cell_d_merged.CONFIG) + print("All kernels compiled/cached.\n") + + # ---- Generate per-layer synthetic inputs (all 16 layers) ---- + layer_inputs_per_layer = [ + _synthetic_layer_inputs(L, cell_d_merged.CONFIG) for L in range(16) + ] + + # ---- Pre-load weights for Cell B and Cell C (both kernel-groups, all 16 layers) ---- + print("=== Pre-loading weights for Cell B and Cell C ===") + rg_weights = [ + {k: li[k] for k in ["norm_w", "wq", "wk", "wv", "lut_q", "lut_k"]} + for li in layer_inputs_per_layer + ] + of_weights = [ + {k: li[k] for k in ["wo", "ffn_norm_w", "w_gate", "w_up", "w_down"]} + for li in layer_inputs_per_layer + ] + + cell_b_static.preload_cell_b( + cache, RG_SPEC, rg_weights, cell_d_merged.CONFIG, RMS_GEMMS_ROPE_BACKEND + ) + cell_b_static.preload_cell_b( + cache, OF_SPEC, of_weights, cell_d_merged.CONFIG, O_FFN_BACKEND + ) + cell_c_charitable.preload_cell_c( + cache, RG_SPEC, rg_weights, cell_d_merged.CONFIG, RMS_GEMMS_ROPE_BACKEND + ) + cell_c_charitable.preload_cell_c( + cache, OF_SPEC, of_weights, cell_d_merged.CONFIG, O_FFN_BACKEND + ) + print("Preload done.\n") + + results = { + "config": cell_d_merged.CONFIG, + "trials": args.trials, + "scope": args.scope, + "cells": {}, + } + + # ---- Timing: 16-layer scope ---- + if args.scope in ("16-layer", "both"): + print("=== Timing: 16-layer scope ===") + for cell in ("A", "B", "C", "D"): + # Unload all previously opened XRT contexts and BOs before each + # cell's 16-layer run. The NPU has ~16 HW context slots; Cell A/B/C + # each need 14 standalone contexts + FA = 15 total. Starting fresh + # per cell avoids hitting the limit. + # Cell B/C weights are lost with the BOs -- re-preload them below. + _unload_all_contexts(cache) + + # Re-preload weights for B and C after the context reset. + if cell == "B": + cell_b_static.preload_cell_b( + cache, RG_SPEC, rg_weights, cell_d_merged.CONFIG, RMS_GEMMS_ROPE_BACKEND, + ) + cell_b_static.preload_cell_b( + cache, OF_SPEC, of_weights, cell_d_merged.CONFIG, O_FFN_BACKEND + ) + elif cell == "C": + cell_c_charitable.preload_cell_c( + cache, RG_SPEC, rg_weights, cell_d_merged.CONFIG, RMS_GEMMS_ROPE_BACKEND, + ) + cell_c_charitable.preload_cell_c( + cache, OF_SPEC, of_weights, cell_d_merged.CONFIG, O_FFN_BACKEND + ) + + # ... timing loop (see shipped run_ablation.py for full implementation) ... + print() + + # ---- Dump JSON ---- + out_path = args.out or f"results_prefill_{int(time.time())}.json" + with open(out_path, "w") as f: + json.dump(results, f, indent=2) + print(f"Wrote {out_path}") + + +if __name__ == "__main__": + main() +``` + +> The full implementation (validation loops, single-layer timing, 16-layer timing, output +> key adapters) lives in the shipped `run_ablation.py`. The skeleton above captures the +> structural changes from the wash-up; see the committed file for the complete code. + +Output JSON shape (target): + +```json +{ + "config": {...}, + "trials": 5, + "cells": { + "A": { + "rms_gemms_rope": {"validation": "PASS", "single_layer": {...}, "16_layer": {...}}, + "o_ffn": {"validation": "PASS", "single_layer": {...}, "16_layer": {...}}, + "16_layer_total": {"median_s": ..., ...} + }, + "B": {...}, "C": {...}, "D": {...} + } +} +``` + +- [ ] **Step 2: Run end-to-end (5 trials, both scopes)** + +```bash +cd programming_examples/llama32_1b/ablation/prefill/build +flock -x -w 1800 /tmp/mlir-air-npu.lock python3 ../run_ablation.py --trials 5 --scope both --out results_pilot.json +``` + +Expected output: validation lines for all 4 cells × 2 kernel-groups (8 × PASS), then timing lines for single-layer and 16-layer scopes per cell. Total run time ~5-10 min. + +The 16-layer Cell D total wall time is the **headline** number — should be in the ballpark of `profile.md`'s 1.27 s. + +- [ ] **Step 3: Commit** + +```bash +git add programming_examples/llama32_1b/ablation/prefill/run_ablation.py +git commit -m "ablation/prefill: orchestrator runs all cells × both kernel-groups × both scopes" +``` + +--- + +## Phase 6 — Report + Docs (Tasks 17–19) + +## Task 17: `analyze.py` report generator + +**Files:** +- Create: `programming_examples/llama32_1b/ablation/prefill/analyze.py` + +- [ ] **Step 1: Write the analyzer** + +```python +"""Read prefill results JSON and emit a markdown report. + +Sections: +- Validation badge (per cell × kernel-group) +- Single-layer per-call medians (per cell × kernel-group) +- 16-layer total wall (per cell, with comparison to profile.md's 1.27s) +- Marginal deltas (A→B, B→C, C→D, A→D — per kernel-group AND aggregated) +- Per-launch breakdown extracted from Cell C's single-layer timing data +""" + +import argparse +import json +import os +import time + +PROFILE_MD_HEADLINE_S = 1.27 # production prefill from profile.md + + +def report(results): + cells = results["cells"] + out = [] + out.append("# Prefill Ablation — Report\n") + out.append(f"Trials: {results['trials']}, config: seq={results['config']['seq_len']}, " + f"emb={results['config']['emb_dim']}, hidden={results['config']['hidden_dim']}\n") + + # Validation table + out.append("## Validation\n") + out.append("| Cell | rms_gemms_rope | o_ffn |") + out.append("|------|----------------|-------|") + for c in ("A", "B", "C", "D"): + rg = cells.get(c, {}).get("rms_gemms_rope", {}).get("validation", "—") + of = cells.get(c, {}).get("o_ffn", {}).get("validation", "—") + out.append(f"| {c} | {rg} | {of} |") + out.append("") + + # Single-layer per-call timing table + out.append("## Single-layer per-call medians (ms)\n") + out.append("| Cell | rms_gemms_rope | o_ffn |") + out.append("|------|----------------|-------|") + for c in ("A", "B", "C", "D"): + rg_s = cells.get(c, {}).get("rms_gemms_rope", {}).get("single_layer", {}).get("median_s") + of_s = cells.get(c, {}).get("o_ffn", {}).get("single_layer", {}).get("median_s") + rg_str = f"{rg_s*1000:.2f}" if rg_s is not None else "—" + of_str = f"{of_s*1000:.2f}" if of_s is not None else "—" + out.append(f"| {c} | {rg_str} | {of_str} |") + out.append("") + + # 16-layer headline table + out.append("## 16-layer total wall (s) — comparable to profile.md's 1.27 s\n") + out.append("| Cell | Median (s) | Min (s) | Max (s) | vs profile.md |") + out.append("|------|------------|---------|---------|---------------|") + for c in ("A", "B", "C", "D"): + e = cells.get(c, {}).get("16_layer_total", {}) + if not e: + out.append(f"| {c} | — | — | — | — |") + continue + md = e["median_s"]; mn = e["min_s"]; mx = e["max_s"] + ratio = md / PROFILE_MD_HEADLINE_S + out.append(f"| {c} | {md:.3f} | {mn:.3f} | {mx:.3f} | {ratio:.2f}× |") + out.append("") + + # Marginal deltas (16-layer total) + out.append("## Marginal deltas (16-layer total)\n") + def m(c): return cells.get(c, {}).get("16_layer_total", {}).get("median_s") + pairs = [ + ("A→B (= #2 per-layer weight BOs)", "A", "B"), + ("B→C (= #3 shared intermediate BOs)", "B", "C"), + ("C→D (= #1 multi-launch merging, isolated)", "C", "D"), + ("A→D (= total dispatch-related speedup)", "A", "D"), + ] + out.append("| Comparison | Δ s | Speedup |") + out.append("|------------|-----|---------|") + for label, a, b in pairs: + ma, mb = m(a), m(b) + if ma is None or mb is None: + out.append(f"| {label} | — | — |") + continue + out.append(f"| {label} | {ma - mb:+.3f} | {ma/mb:.2f}× |") + out.append("") + + return "\n".join(out) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("results_json") + ap.add_argument("--out", default=None) + args = ap.parse_args() + with open(args.results_json) as f: + results = json.load(f) + text = report(results) + out = args.out or f"report_prefill_{int(time.time())}.md" + with open(out, "w") as f: + f.write(text) + print(f"Wrote {out}\n") + print(text) + + +if __name__ == "__main__": + main() +``` + +- [ ] **Step 2: Generate report** + +```bash +cd programming_examples/llama32_1b/ablation/prefill/build +python3 ../analyze.py results_pilot.json --out report_pilot.md +cat report_pilot.md +``` + +Expected: a markdown report with all 4 cells' validation, single-layer medians, 16-layer totals, and marginal deltas. The Cell D 16-layer total should be in the ballpark of 1.27 s (the headline confirmation). + +- [ ] **Step 3: Commit** + +```bash +git add programming_examples/llama32_1b/ablation/prefill/analyze.py +git commit -m "ablation/prefill: markdown report generator with profile.md comparison" +``` + +--- + +## Task 18: README + Makefile + +**Files:** +- Create: `programming_examples/llama32_1b/ablation/prefill/Makefile` +- Create: `programming_examples/llama32_1b/ablation/prefill/README.md` + +- [ ] **Step 1: Write Makefile** + +```make +# Llama-3.2-1B prefill ablation harness +# +# make compile — compile all standalone ELFs + Cell D's 2 merged ELFs + FA (~10-15 min, cached) +# make regen-golden — regenerate committed golden fixtures (rare; only after Cell D changes) +# make run — run all 4 cells × 2 kernel-groups × both scopes, emit JSON +# make report — generate markdown report from latest results JSON +# make all — compile + run + report +# make clean — wipe build/ + +srcdir := $(shell dirname $(realpath $(firstword $(MAKEFILE_LIST)))) +BUILD := build + +.PHONY: help compile regen-golden run report all clean + +help: + @echo "make compile | regen-golden | run | report | all | clean" + +compile: + @mkdir -p $(BUILD) + cd $(BUILD) && PYTHONPATH=$(srcdir):$(srcdir)/..:$(srcdir)/../..:$(srcdir)/../../..:$$PYTHONPATH flock -x -w 1800 /tmp/mlir-air-npu.lock python3 -m cells.common + +regen-golden: compile + cd $(BUILD) && PYTHONPATH=$(srcdir):$(srcdir)/..:$(srcdir)/../..:$(srcdir)/../../..:$$PYTHONPATH flock -x -w 1800 /tmp/mlir-air-npu.lock python3 $(srcdir)/golden/regen_golden.py + +run: compile + cd $(BUILD) && PYTHONPATH=$(srcdir):$(srcdir)/..:$(srcdir)/../..:$(srcdir)/../../..:$$PYTHONPATH flock -x -w 1800 /tmp/mlir-air-npu.lock python3 $(srcdir)/run_ablation.py --out results_latest.json + +report: + cd $(BUILD) && python3 $(srcdir)/analyze.py results_latest.json --out report_latest.md && cat report_latest.md + +all: compile run report + +clean: + rm -rf $(BUILD) +``` + +- [ ] **Step 2: Write README.md** + +```markdown +# Llama-3.2-1B Prefill Ablation (Plan 2) + +Bit-exact 4-cell ablation of the production **prefill** pipeline: +`rms_gemms_rope` (6 launches) + FlashAttention (held constant) + `o_ffn` +(8 launches), at seq=2048 GEMM shapes, both single-layer and full 16-layer +scopes. + +Companion docs: +- Plan 2 spec: [`ablation/docs/specs/2026-05-07-llama32-1b-ablation-plan2-prefill-design.md`](../specs/2026-05-07-llama32-1b-ablation-plan2-prefill-design.md) +- Plan 1 (decode pilot): removed from repo (subsumed by full-decode study at `ablation/decode/`) +- Production profile: [`../../../docs/profile.md`](../../../docs/profile.md) + +## What this measures + +Four cells, identical computation, different dispatch strategy: + +| Cell | What changes within each kernel-group | Adds | +|------|---------------------------------------|------| +| A | 6 + 8 separate `xrt.run()` per layer, host round-trip on every intermediate | (baseline) | +| B | + per-layer weight BOs (`static_input_indices`) | #2 | +| C | + shared intermediate BOs across separate `xrt.run()` calls | #3 | +| D | + multi-launch merging (production: 6→1 + 8→1 ELF per layer) | #1 | + +FA is held constant per spec (un-mergeable). Cross-kernel-group transfers +(rms→FA, FA→o_ffn) go through host in every cell — matches production. + +## Quick start + +``` +make compile # one-time, ~10-15 min for 14 standalone ELFs + 2 merged + FA +make run # 5 trials × both scopes × all 4 cells (~5-10 min) +make report # markdown report +``` + +## Validation gate + +Every cell's per-kernel-group output must match the committed `golden/*.npz` +fixtures bit-exactly (synthetic numpy seed=42 inputs). Cells failing the +gate suppress their timing in the report. + +## Reproducibility + +``` +cd programming_examples/llama32_1b/ablation/prefill +make clean && make all +``` + +The 16-layer Cell D total wall time should be in the ballpark of +`profile.md`'s **1.27 s** production headline. The marginal deltas table +attributes how much each of optimizations #1, #2, #3 contributes to that +number for prefill specifically. + +Unit tests (NPU-free): + +``` +python3 -m pytest tests/ -v +``` + +## Limitations of this plan (Plan 2-decode and Plan 2-lm-head will address) + +- Prefill only — decode `o_gemv_ffn` and the LM Head L1/L8 mini-study are + separate plans. +- FA is invariant in every cell. A potential **Plan 2.5** could ablate + cross-kernel-group BO sharing (FA's input BOs aliased to rms_gemms_rope's + output BOs); production doesn't currently do this. +- Synthetic weights only. No HuggingFace. + +## File map + +| Path | Purpose | +|------|---------| +| `specs/kernel_group.py` | Frozen dataclasses | +| `specs/{rms_gemms_rope,o_ffn}.py` | Concrete spec instances | +| `standalone_builders/` | Re-exported STANDALONES registries | +| `cells/cell_{a,b,c,d}_*.py` | Parameterized cell harnesses | +| `cells/flash_attn_const.py` | FA invariant | +| `cells/multi_layer.py` | 16-layer wrapper | +| `cells/common.py` | Compile harness, BO baton-pass helper | +| `golden/` | Two committed npz fixtures + regen script | +| `validate.py` | Parameterized bit-exact gate | +| `run_ablation.py` | Orchestrator | +| `analyze.py` | Report generator | +| `Makefile` | Convenience targets | +``` + +- [ ] **Step 3: Smoke test** + +```bash +cd programming_examples/llama32_1b/ablation/prefill && make help +``` +Expected: prints help line. + +- [ ] **Step 4: Commit** + +```bash +git add programming_examples/llama32_1b/ablation/prefill/Makefile \ + programming_examples/llama32_1b/ablation/prefill/README.md +git commit -m "ablation/prefill: README + Makefile" +``` + +--- + +## Task 19: End-to-end smoke + final commit + +- [ ] **Step 1: Wipe build/ and run from scratch** + +```bash +cd programming_examples/llama32_1b/ablation/prefill +make clean +make all +``` + +Expected: ~10-15 min compile, ~5-10 min run, ~1 sec report. Final report shows all 4 cells × 2 kernel-groups PASS validation, with 16-layer Cell D total in the 1.0-1.5 s range (headline confirmation). + +- [ ] **Step 2: Run unit tests** + +```bash +cd programming_examples/llama32_1b/ablation/prefill && python3 -m pytest tests/ -v +``` + +Expected: all tests pass (kernel_group_spec: 4, validation_gate: 4, parameterized_cells: variable). + +- [ ] **Step 3: Verify Plan 1 isolation** + +```bash +git diff llama-3.2-1B-devel..HEAD --stat -- programming_examples/llama32_1b/ablation/ | grep -v '^ programming_examples/llama32_1b/ablation/prefill/' +``` + +Expected: empty output (no Plan 1 files modified). + +- [ ] **Step 4: Final commit (if any uncommitted artifacts)** + +```bash +cd /home/jiajli/apps/mlir-air +git status +``` + +If clean: nothing to do. Otherwise update `.gitignore` and commit: + +```bash +git commit -m "ablation/prefill: final cleanup" +``` + +--- + +## Self-Review Checklist + +**Spec coverage** (against `programming_examples/llama32_1b/ablation/docs/specs/2026-05-07-llama32-1b-ablation-plan2-prefill-design.md`): + +- §3 4-cell ladder for both kernel-groups → Tasks 8 (D), 12 (A), 13 (B), 14 (C) ✓ +- §4 Invariants (FA constant, decode files unmodified, etc.) → Tasks 11 (FA), 19 (Plan 1 isolation check) ✓ +- §5 Correctness verification (golden + per-cell + cross-cell) → Tasks 9, 10, 12-14 ✓ (cross-cell consistency re-check is in the orchestrator T16 — implementer should add a re-validation pass after timing) +- §6 Per-launch breakdown via Cell C → falls out of orchestrator T16 (records per-call write/kernel/read) + analyzer T17 (could be augmented with a per-launch breakdown table; this plan ships the JSON shape that supports it) +- §7 Host overhead → falls out of (wall - Σ(write+kernel+read)); analyzer T17 can add a row for it +- §8.1 Self-contained subdir → T1 ✓ +- §8.2 KernelGroupSpec dataclass → T2 ✓ +- §8.3 Standalone 1-launch ELFs → T5, T6 ✓ +- §8.4 Cell-specific harness (parameterized) → T12-T14 ✓ +- §8.5 Validation reuse → T10 ✓ +- §8.6 Orchestrator scopes (single-layer + 16-layer) → T15 (multi_layer wrapper), T16 (orchestrator with --scope) ✓ +- §9 Stats: 5 trials, drop run 1, median + range → T16 `_time_runs` ✓ +- §10 Deliverable structure → matches file structure section above ✓ +- §11 Out of scope → respected (no Plan 2-decode, no LM Head, no real HF weights) +- §12 Isolation strategy: worktree + Plan 1 files unmodified → T19 Step 3 verification ✓ +- §13 Risks → flagged in Tasks 7 (compile time), 12 (variance), 14 (BO aliasing debug) + +**Placeholder scan**: searched for "TBD", "TODO", "fill in", "implement later" — none in the plan body. The orchestrator T16 has explicit `pass` placeholders documented as "for the implementer to fill in"; this is intentional because the cell function signatures are clarified in T12-T14 and the orchestrator wires them up. + +**Type consistency**: `KernelCache.naive=True` (Plan 1, already shipped), `compile_standalone_kernels(cache, group_name, registry, backend_preset)` signature consistent across T7, T12, T13, T14. `_share_bo` signature consistent with Plan 1's. `BatonLink` and `SubLaunchSpec` field names consistent across T2, T3, T4, T12-T14. + +**Coverage gaps that are intentional and documented**: +- Cross-cell consistency re-check (§5 of spec) is described as belonging in T16's orchestrator but not concretely coded — implementer should add it after the per-cell validation loop. +- Per-launch breakdown table in the report is supported by the JSON shape but not rendered by the analyzer in T17. Plan 2's primary goal is the headline number; per-launch table can be added in a wash-up. +- Cell A/B/C parameterized harnesses (T12-T14) leave the layer_inputs-to-args adapter to the implementer's iteration; the spec dataclass is the contract but the concrete naming convention (e.g., `_out_buf_`, `_w`) needs refinement during T12. diff --git a/programming_examples/llama32_1b/ablation/docs/plans/2026-05-12-llama32-1b-ablation-plan2-fulldecode-plan.md b/programming_examples/llama32_1b/ablation/docs/plans/2026-05-12-llama32-1b-ablation-plan2-fulldecode-plan.md new file mode 100644 index 000000000..262333992 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/docs/plans/2026-05-12-llama32-1b-ablation-plan2-fulldecode-plan.md @@ -0,0 +1,1121 @@ +# Llama-3.2-1B Plan 2 (Full Decode) Ablation Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Build the 4-cell ablation ladder for the **full decode** path: 16 layers × (`rms_gemv_rope` 6 launches + `decode_attention_cpu` + `o_gemv_ffn` 8 launches) + final RMSNorm + `lm_head_gemv` 8-partition + argmax. Single decode token per timed trial, 5 trials, drop warmup. Bit-exact validation against committed goldens. Headline number directly comparable to `profile.md`'s per-token decode latency. + +**Architecture:** Self-contained subdir `programming_examples/llama32_1b/ablation/decode/` (Plan 0 files at `ablation/` and Plan 1 files at `ablation/prefill/` remain byte-immutable). The 4 parameterized cell modules from Plan 1 are reused via direct import or copy; the new work is (a) `o_gemv_ffn` standalone builders + spec, (b) the per-token loop wrapper, (c) KV cache state management, (d) the `lm_head_gemv` invariant runner, (e) goldens + orchestration + report. + +**Tech Stack:** Same as Plan 1 — Python 3, numpy, ml_dtypes (bfloat16), pytest, mlir-air's `XRTBackend` + `KernelCache`. Production builders imported: `build_rms_gemv_rope_module`, `build_o_gemv_ffn_module`, `build_lm_head_gemv_module` from `multi_launch_builder/`. + +**Companion docs:** +- Plan 2 spec: `programming_examples/llama32_1b/ablation/docs/specs/2026-05-12-llama32-1b-ablation-plan2-fulldecode-design.md` +- Master ablation spec: removed from repo (decode pilot deleted; this full-decode study supersedes it) +- Plan 0 (decode pilot): removed from repo (subsumed by this study) +- Plan 1 (full prefill): `programming_examples/llama32_1b/ablation/docs/plans/2026-05-07-llama32-1b-ablation-plan2-prefill.md` — primary template for code patterns +- Plan 1 working code: `programming_examples/llama32_1b/ablation/prefill/` — copy-paste reference +- Plan 0 working code: removed; the standalone builder content is now inlined into `programming_examples/llama32_1b/ablation/decode/standalone_builders/rms_gemv_rope.py` +- Audience-facing summary: `programming_examples/llama32_1b/docs/ABLATION_STUDY.html` + +**Branch / worktree setup:** Create a NEW worktree (e.g., `ablation-plan2-fulldecode`) from `llama-3.2-1B-devel`. Do NOT modify Plan 0/1 directories. + +--- + +## File Structure + +All paths under `programming_examples/llama32_1b/ablation/decode/` unless noted. + +| File | Responsibility | Source pattern | +|------|----------------|----------------| +| `__init__.py` | Package marker | — | +| `README.md` | Methodology, run instructions, results, reproducibility | Plan 1's README | +| `Makefile` | `make compile / regen-golden / run / report / all / clean` | Plan 1's Makefile | +| `specs/__init__.py` | Package marker | — | +| `specs/kernel_group.py` | Re-export `SubLaunchSpec`, `BatonLink`, `KernelGroupSpec` from Plan 1 (single source of truth) | `from ablation.prefill.specs.kernel_group import *` | +| `specs/rms_gemv_rope.py` | Concrete spec for the 6-launch decode attention pre-block | Plan 1's `specs/rms_gemms_rope.py` adapted | +| `specs/o_gemv_ffn.py` | Concrete spec for the 8-launch decode FFN block | Plan 1's `specs/o_ffn.py` adapted (GEMV instead of GEMM, mv_k8192 for Down) | +| `standalone_builders/__init__.py` | Package marker | — | +| `standalone_builders/rms_gemv_rope.py` | Re-export Plan 0's `STANDALONES` registry | `from ablation.standalone_builders.decode_rms_gemv_rope import STANDALONES` | +| `standalone_builders/o_gemv_ffn.py` | 8 single-launch builder wrappers + `STANDALONES` registry — NEW | Plan 1's `standalone_builders/o_ffn.py` adapted | +| `cells/__init__.py` | Package marker | — | +| `cells/common.py` | Re-export Plan 1's `compile_standalone_kernels`, `_share_bo`, `_extract_public_func_name` | `from ablation.prefill.cells.common import *` | +| `cells/cell_a_naive.py` | Parameterized Cell A — direct re-export from Plan 1 | `from ablation.prefill.cells.cell_a_naive import run_cell_a, compile_cell_a` | +| `cells/cell_b_static.py` | Parameterized Cell B | re-export from Plan 1 | +| `cells/cell_c_charitable.py` | Parameterized Cell C | re-export from Plan 1 | +| `cells/cell_d_merged.py` | Wraps production `build_rms_gemv_rope_module`, `build_o_gemv_ffn_module` | Plan 1's `cell_d_merged.py` adapted | +| `cells/decode_attn_const.py` | CPU attention invariant — same Python function in every cell | NEW (Plan 1's `flash_attn_const.py` pattern) | +| `cells/lm_head_const.py` | LM head invariant — production-merged 8-partition GEMV | NEW | +| `cells/per_token_loop.py` | The end-to-end timed unit: 16 layers + final RMSNorm + LM head + argmax | NEW (Plan 1's `multi_layer.py` adapted, replacing 16-prompt-position with 1-decode-token) | +| `cells/kv_cache.py` | KV cache state init + per-trial reset | NEW | +| `golden/__init__.py` | Package marker | — | +| `golden/regen_golden.py` | One-shot Cell-D run; dumps two npz fixtures + meta json | Plan 1's regen pattern | +| `golden/golden_rms_gemv_rope_decode.npz` | Cell D output, layer 0, seed=42, current_pos=7 | Generated | +| `golden/golden_o_gemv_ffn_decode.npz` | Cell D output for o_gemv_ffn | Generated | +| `golden/golden_meta.json` | Hashes, shapes, prompt_len, current_pos | Plan 1 | +| `validate.py` | Bit-exact gate, parameterized — re-export Plan 1's `validate.py` directly | `from ablation.prefill.validate import *` | +| `run_ablation.py` | Orchestrator | Plan 1 adapted | +| `analyze.py` | JSON → markdown report | Plan 1 adapted | +| `tests/__init__.py` | Package marker | — | +| `tests/conftest.py` | Pytest sys.path setup | Plan 1 | +| `tests/test_o_gemv_ffn_spec.py` | Dataclass invariants for the new `o_gemv_ffn` spec | NEW | +| `tests/test_kv_cache_state.py` | Verifies cache initialization + per-trial reset is deterministic | NEW | +| `tests/test_validation_gate.py` | Tests against the two new decode goldens | Plan 1 adapted | + +**Files NOT touched** (isolation guarantee): every file under `programming_examples/llama32_1b/ablation/` outside `decode/`. Production code under `programming_examples/llama32_1b/{kernel_builder,multi_launch_builder}/` is read-only — only imported. + +--- + +## Phase 1 — Skeleton + reused infrastructure (Tasks 1–3) + +## Task 1: Worktree + subdir skeleton + conftest + +**Files:** +- Create: `programming_examples/llama32_1b/ablation/decode/` with subdirs `specs/`, `standalone_builders/`, `cells/`, `golden/`, `tests/` +- Create: 7 `__init__.py` files +- Create: `decode/tests/conftest.py` + +- [ ] **Step 1: Set up worktree** + +```bash +cd /home/jiajli/apps/mlir-air +git worktree add .claude/worktrees/ablation-plan2-fulldecode llama-3.2-1B-devel +cd .claude/worktrees/ablation-plan2-fulldecode +git checkout -b llama32_1b/ablation-plan2-fulldecode +``` + +- [ ] **Step 2: Create directory tree + package markers** + +```bash +DECODE=programming_examples/llama32_1b/ablation/decode +mkdir -p $DECODE/{specs,standalone_builders,cells,golden,tests} +for d in "" /specs /standalone_builders /cells /golden /tests; do + touch $DECODE$d/__init__.py +done +``` + +- [ ] **Step 3: Write conftest.py** + +`programming_examples/llama32_1b/ablation/decode/tests/conftest.py`: + +```python +"""Pytest config for full-decode ablation tests. + +Inserts paths so tests can import: +- llama32_1b/ packages (kernel_builder, multi_launch_builder) +- llama32_1b/ablation/ (Plan 0's standalone_builders + validate.py) +- llama32_1b/ablation/prefill/ (Plan 1's cells, specs, common helpers) +- llama32_1b/ablation/decode/ (this package) +- programming_examples/ (matvec, weighted_rms_norm, ffn_swiglu) +""" + +import os +import sys + +_THIS = os.path.dirname(os.path.abspath(__file__)) +_DECODE = os.path.dirname(_THIS) +_ABLATION = os.path.dirname(_DECODE) +_LLAMA = os.path.dirname(_ABLATION) +_PROG_EXAMPLES = os.path.dirname(_LLAMA) + +for p in (_PROG_EXAMPLES, _LLAMA, _ABLATION, os.path.join(_ABLATION, "prefill"), _DECODE): + if p not in sys.path: + sys.path.insert(0, p) +``` + +- [ ] **Step 4: Verify imports work** + +```bash +cd programming_examples/llama32_1b/ablation/decode +python3 -c "import sys; sys.path.insert(0, '.'); sys.path.insert(0, '..'); from ablation.prefill.specs.kernel_group import KernelGroupSpec; print('OK')" +``` + +Expected: prints `OK` (Plan 1's KernelGroupSpec dataclass loads). + +- [ ] **Step 5: Commit** + +```bash +git add programming_examples/llama32_1b/ablation/decode +git commit -m "ablation-decode: skeleton subdir + package markers + conftest" +``` + +## Task 2: Re-exports — kernel_group, common, validate + +**Files:** +- Create: `decode/specs/kernel_group.py` +- Create: `decode/cells/common.py` +- Create: `decode/validate.py` +- Create: `decode/cells/cell_a_naive.py`, `cell_b_static.py`, `cell_c_charitable.py` (re-exports) + +- [ ] **Step 1: Re-export the spec dataclasses** + +`decode/specs/kernel_group.py`: + +```python +"""Re-export Plan 1's KernelGroupSpec dataclasses (single source of truth).""" + +from ablation.prefill.specs.kernel_group import ( + SubLaunchSpec, + BatonLink, + KernelGroupSpec, +) + +__all__ = ["SubLaunchSpec", "BatonLink", "KernelGroupSpec"] +``` + +- [ ] **Step 2: Re-export the common helpers** + +`decode/cells/common.py`: + +```python +"""Re-export Plan 1's common helpers.""" + +from ablation.prefill.cells.common import ( + compile_standalone_kernels, + _share_bo, + _extract_public_func_name, + standalone_backend_kwargs, +) + +__all__ = [ + "compile_standalone_kernels", + "_share_bo", + "_extract_public_func_name", + "standalone_backend_kwargs", +] +``` + +- [ ] **Step 3: Re-export the validate gate** + +`decode/validate.py`: + +```python +"""Re-export Plan 1's parameterized bit-exact validation gate.""" + +from ablation.prefill.validate import ( + validate_against_golden, + GoldenMismatch, +) + +__all__ = ["validate_against_golden", "GoldenMismatch"] +``` + +- [ ] **Step 4: Re-export Cells A/B/C (parameterized — work for any KernelGroupSpec)** + +`decode/cells/cell_a_naive.py`: + +```python +"""Re-export Plan 1's parameterized Cell A — same code, decode spec at call site.""" + +from ablation.prefill.cells.cell_a_naive import run_cell_a, compile_cell_a + +__all__ = ["run_cell_a", "compile_cell_a"] +``` + +(Same pattern for `cell_b_static.py` and `cell_c_charitable.py`.) + +- [ ] **Step 5: Smoke test the re-exports** + +```bash +cd programming_examples/llama32_1b/ablation/decode +python3 -c "from cells.cell_a_naive import run_cell_a; from validate import validate_against_golden; print('imports OK')" +``` + +- [ ] **Step 6: Commit** + +```bash +git add programming_examples/llama32_1b/ablation/decode +git commit -m "ablation-decode: re-export Plan 1's KernelGroupSpec, helpers, validate, cells A-C" +``` + +## Task 3: Re-export rms_gemv_rope standalone builders from Plan 0 + +**Files:** +- Create: `decode/standalone_builders/rms_gemv_rope.py` + +- [ ] **Step 1: Write the re-export** + +`decode/standalone_builders/rms_gemv_rope.py`: + +```python +"""Re-export Plan 0's existing decode_rms_gemv_rope standalone builders. + +Plan 0 already built 6 single-launch wrappers for rms_gemv_rope's sub-launches. +Plan 2 reuses them verbatim. +""" + +from ablation.standalone_builders.decode_rms_gemv_rope import STANDALONES + +__all__ = ["STANDALONES"] +``` + +- [ ] **Step 2: Verify** + +```bash +cd programming_examples/llama32_1b/ablation/decode +python3 -c "from standalone_builders.rms_gemv_rope import STANDALONES; assert len(STANDALONES) == 6; print('rms_gemv_rope STANDALONES re-exported, count =', len(STANDALONES))" +``` + +Expected: prints `rms_gemv_rope STANDALONES re-exported, count = 6` + +- [ ] **Step 3: Commit** + +```bash +git commit -am "ablation-decode: re-export rms_gemv_rope STANDALONES from Plan 0" +``` + +--- + +## Phase 2 — New work for o_gemv_ffn (Tasks 4–6) + +## Task 4: o_gemv_ffn KernelGroupSpec + +**Files:** +- Create: `decode/specs/o_gemv_ffn.py` + +This spec describes the 8 sub-launches of `o_gemv_ffn`: O GEMV, eltwise add (residual #1), RMSNorm, Gate GEMV, Up GEMV, SwiGLU (silu_and_mul), Down GEMV (uses `mv_k8192.o`), eltwise add (residual #2). Slot semantics + baton links for Cell C aliasing. + +- [ ] **Step 1: Write the failing test first** + +`tests/test_o_gemv_ffn_spec.py`: + +```python +"""Validate the o_gemv_ffn KernelGroupSpec structure.""" + +from specs.o_gemv_ffn import O_GEMV_FFN_SPEC + + +def test_spec_has_8_sublaunches(): + assert len(O_GEMV_FFN_SPEC.sub_launches) == 8 + + +def test_sublaunch_names_match_production_order(): + names = [s.name for s in O_GEMV_FFN_SPEC.sub_launches] + assert names == [ + "o_gemv", "add_attn_residual", "ffn_rmsnorm", + "gate_gemv", "up_gemv", "swiglu", + "down_gemv_k8192", "add_ffn_residual", + ] + + +def test_baton_links_cover_all_intermediate_handoffs(): + """Every intermediate output must have a baton link to the next consumer.""" + # 7 intermediates × 1 producer-consumer link each (linear chain except the gate→swiglu and up→swiglu fork) + # Detailed expected: o_gemv→add_attn, add_attn→ffn_rmsnorm, ffn_rmsnorm→{gate,up,save_residual}, + # gate→swiglu, up→swiglu, swiglu→down_gemv, down_gemv→add_ffn + expected_links = [...] + assert sorted(O_GEMV_FFN_SPEC.baton_links) == sorted(expected_links) +``` + +- [ ] **Step 2: Run test to confirm it fails** + +```bash +cd programming_examples/llama32_1b/ablation/decode +python3 -m pytest tests/test_o_gemv_ffn_spec.py -v +``` + +Expected: ImportError or test failure (spec doesn't exist yet). + +- [ ] **Step 3: Write the spec** + +`decode/specs/o_gemv_ffn.py`: + +```python +"""KernelGroupSpec for the 8-launch o_gemv_ffn decode kernel-group. + +Production: rms_gemms_rope's sister for the second half of a decode layer. +Stitched into one ELF in production (Cell D); Cell A/B/C run all 8 as +separate xrt.run() calls. +""" + +from specs.kernel_group import SubLaunchSpec, BatonLink, KernelGroupSpec +# (Concrete instance follows. Mirror structure from prefill/specs/o_ffn.py +# but adapt for GEMV (single-token) shapes and the mv_k8192 down-step.) + +O_GEMV_FFN_SPEC = KernelGroupSpec( + name="o_gemv_ffn", + sub_launches=[ + # ... 8 SubLaunchSpec entries ... + ], + baton_links=[ + # ... intermediate handoff edges ... + ], +) +``` + +(Full content needs careful adaptation of Plan 1's `o_ffn` spec to single-token GEMV shapes — a ~200-line file.) + +- [ ] **Step 4: Run test to confirm pass** + +Expected: 3 passed. + +- [ ] **Step 5: Commit** + +```bash +git add specs/o_gemv_ffn.py tests/test_o_gemv_ffn_spec.py +git commit -m "ablation-decode: o_gemv_ffn KernelGroupSpec + tests" +``` + +## Task 5: rms_gemv_rope KernelGroupSpec + +**Files:** +- Create: `decode/specs/rms_gemv_rope.py` + +The 6-sub-launch spec for the decode attention pre-block. Plan 0 had standalone builders but never wrote a formal `KernelGroupSpec` — Plan 1's `KernelGroupSpec` dataclass post-dates Plan 0. Now we need one for the parameterized cell harnesses. + +- [ ] **Step 1: Write spec** + +`decode/specs/rms_gemv_rope.py`: + +```python +"""KernelGroupSpec for the 6-launch rms_gemv_rope decode kernel-group.""" + +from specs.kernel_group import SubLaunchSpec, BatonLink, KernelGroupSpec + +RMS_GEMV_ROPE_SPEC = KernelGroupSpec( + name="rms_gemv_rope", + sub_launches=[ + # rmsnorm, q_gemv, k_gemv, v_gemv, rope_q, rope_k + ], + baton_links=[ + # rmsnorm→q_gemv, rmsnorm→k_gemv, rmsnorm→v_gemv + # q_gemv→rope_q, k_gemv→rope_k + ], +) +``` + +(Reference Plan 0's `cells/cell_a_naive.py` for the slot/argument layout.) + +- [ ] **Step 2: Smoke test it loads** + +```bash +python3 -c "from specs.rms_gemv_rope import RMS_GEMV_ROPE_SPEC; assert len(RMS_GEMV_ROPE_SPEC.sub_launches) == 6" +``` + +- [ ] **Step 3: Commit** + +```bash +git commit -am "ablation-decode: rms_gemv_rope KernelGroupSpec" +``` + +## Task 6: o_gemv_ffn standalone builders + +**Files:** +- Create: `decode/standalone_builders/o_gemv_ffn.py` + +8 single-launch MLIR builder wrappers, one per sub-launch of `o_gemv_ffn`. Mirror Plan 1's `standalone_builders/o_ffn.py` but for GEMV (single-token, M=1) shapes. + +- [ ] **Step 1: Write builders** + +`decode/standalone_builders/o_gemv_ffn.py`: + +```python +"""8 single-launch builder wrappers for o_gemv_ffn sub-launches. + +Each builder produces a full MLIR module containing ONE air.launch. +Used by Cells A/B/C (separate xrt.run() per sub-launch). +Cell D uses the production merged build_o_gemv_ffn_module instead. +""" + +from ml_dtypes import bfloat16 +import numpy as np + +from matvec.run import build_module as _build_matvec +from weighted_rms_norm.weighted_rms_norm import build_module as _build_rmsnorm +from ffn_swiglu.silu_and_mul import build_module as _build_swiglu +from eltwise_add.eltwise_add import build_module as _build_add +# Reuse multi_launch_builder/o_gemv_ffn_multi.py's _build_add_2d_to_1d if needed. + +def build_o_gemv(): ... # 1 air.launch wrapping the O GEMV +def build_add_attn_residual(): ... # 1 air.launch wrapping eltwise add (2D) +def build_ffn_rmsnorm(): ... +def build_gate_gemv(): ... +def build_up_gemv(): ... +def build_swiglu(): ... +def build_down_gemv_k8192(): ... # uses dg_matvec_vectorized_bf16_bf16 (renamed K=8192 variant) +def build_add_ffn_residual(): ... + +STANDALONES = { + "o_gemv": build_o_gemv, + "add_attn_residual": build_add_attn_residual, + "ffn_rmsnorm": build_ffn_rmsnorm, + "gate_gemv": build_gate_gemv, + "up_gemv": build_up_gemv, + "swiglu": build_swiglu, + "down_gemv_k8192": build_down_gemv_k8192, + "add_ffn_residual": build_add_ffn_residual, +} +``` + +- [ ] **Step 2: Smoke test each builder produces a parseable MLIR module (NPU-free)** + +```bash +python3 -c " +from standalone_builders.o_gemv_ffn import STANDALONES +for name, build_fn in STANDALONES.items(): + mod = build_fn() # signature TBD per kernel + assert mod is not None + print(f'{name}: ok') +" +``` + +- [ ] **Step 3: Commit** + +```bash +git add standalone_builders/o_gemv_ffn.py +git commit -m "ablation-decode: 8 standalone builders for o_gemv_ffn sub-launches" +``` + +--- + +## Phase 3 — Decode-specific orchestration (Tasks 7–10) + +## Task 7: KV cache initialization + per-trial reset + +**Files:** +- Create: `decode/cells/kv_cache.py` +- Create: `tests/test_kv_cache_state.py` + +- [ ] **Step 1: Write the failing test** + +`tests/test_kv_cache_state.py`: + +```python +"""KV cache state must be deterministic and resettable per trial.""" + +import numpy as np +from cells.kv_cache import build_initial_kv_cache, reset_position + + +def test_initial_cache_is_deterministic(): + cfg = {"n_layers": 16, "n_kv_heads": 8, "head_dim": 64, "max_seq": 2048} + c1 = build_initial_kv_cache(cfg, prompt_len=7, seed=42) + c2 = build_initial_kv_cache(cfg, prompt_len=7, seed=42) + np.testing.assert_array_equal(c1["k_cache"], c2["k_cache"]) + np.testing.assert_array_equal(c1["v_cache"], c2["v_cache"]) + + +def test_reset_position_clears_target_slot(): + cfg = {"n_layers": 16, "n_kv_heads": 8, "head_dim": 64, "max_seq": 2048} + cache = build_initial_kv_cache(cfg, prompt_len=7, seed=42) + cache["k_cache"][0, :, 7, :] = 99.0 # simulate write + reset_position(cache, 7) + assert (cache["k_cache"][0, :, 7, :] == 0).all() + # positions 0-6 untouched + assert not (cache["k_cache"][0, :, :7, :] == 0).all() +``` + +- [ ] **Step 2: Implement** + +`decode/cells/kv_cache.py`: + +```python +"""KV cache state management for the per-token timed loop. + +Two functions: +- build_initial_kv_cache: deterministic synthetic pre-fill of `prompt_len` positions +- reset_position: zero out a specific position (called between trials) +""" + +import numpy as np +from ml_dtypes import bfloat16 + + +def build_initial_kv_cache(config, prompt_len, seed): + """Pre-fill the KV cache with synthetic deterministic values.""" + rng = np.random.default_rng(seed) + shape = (config["n_layers"], config["n_kv_heads"], config["max_seq"], config["head_dim"]) + k = np.zeros(shape, dtype=bfloat16) + v = np.zeros(shape, dtype=bfloat16) + k[:, :, :prompt_len, :] = rng.standard_normal( + (config["n_layers"], config["n_kv_heads"], prompt_len, config["head_dim"]) + ).astype(bfloat16) * 0.5 + v[:, :, :prompt_len, :] = rng.standard_normal( + (config["n_layers"], config["n_kv_heads"], prompt_len, config["head_dim"]) + ).astype(bfloat16) * 0.5 + return {"k_cache": k, "v_cache": v, "current_pos": prompt_len} + + +def reset_position(cache, pos): + """Zero out the K/V cache slots at `pos` for ALL layers.""" + cache["k_cache"][:, :, pos, :] = 0 + cache["v_cache"][:, :, pos, :] = 0 +``` + +- [ ] **Step 3: Run tests** + +Expected: 2 passed. + +- [ ] **Step 4: Commit** + +```bash +git add cells/kv_cache.py tests/test_kv_cache_state.py +git commit -m "ablation-decode: KV cache init + per-trial reset (tested deterministic)" +``` + +## Task 8: Decode CPU attention invariant runner + +**Files:** +- Create: `decode/cells/decode_attn_const.py` + +Wraps the production `decode_attention_cpu` from `llama32_1b_decode.py:96` so all 4 cells call exactly the same Python function. + +- [ ] **Step 1: Write** + +`decode/cells/decode_attn_const.py`: + +```python +"""Invariant CPU attention runner — same Python function in every cell.""" + +import time +from llama32_1b_decode import decode_attention_cpu + + +def run_decode_attention(cache, q_roped, k_roped, v, layer_idx, current_pos, config): + """Run CPU attention; update KV cache slot at current_pos. + + Returns: (attn_out, elapsed_seconds) + """ + t0 = time.perf_counter() + attn_out = decode_attention_cpu( + q_roped, k_roped, v, + cache["k_cache"][layer_idx], + cache["v_cache"][layer_idx], + current_pos, + config["n_heads"], config["n_kv_heads"], config["head_dim"], + ) + elapsed = time.perf_counter() - t0 + return attn_out, elapsed +``` + +- [ ] **Step 2: Smoke test (NPU-free, dummy inputs)** + +```bash +python3 -c " +from cells.decode_attn_const import run_decode_attention +import numpy as np +from ml_dtypes import bfloat16 +# Construct minimal dummy cache + activation tensors and verify it runs +# ... +print('decode_attn_const runs') +" +``` + +- [ ] **Step 3: Commit** + +```bash +git commit -am "ablation-decode: invariant CPU attention runner" +``` + +## Task 9: LM head invariant runner + +**Files:** +- Create: `decode/cells/lm_head_const.py` + +Production `lm_head_gemv` is one merged ELF (8 stitched partitions); held INVARIANT in every cell. + +- [ ] **Step 1: Write** + +```python +"""Invariant LM head runner — production-merged 8-partition GEMV in every cell.""" + +import time +import numpy as np +from ml_dtypes import bfloat16 + +from kernel_builder.cache import KernelCache +from kernel_builder.backend_presets import LM_GEMV_BACKEND +from multi_launch_builder.lm_head_gemv_multi import build_lm_head_gemv_module + + +def compile_lm_head(cache: KernelCache, config): + """Compile the production LM head ELF (one-time).""" + if "lm_head_gemv" in cache.artifacts: + return + mod = build_lm_head_gemv_module(...) # production args + cache.compile_and_cache("lm_head_gemv", mod, {**LM_GEMV_BACKEND, "verbose": cache.verbose}) + + +def run_lm_head(cache, x_normed, weights, vocab_size): + """Run LM head; return (next_token_id, elapsed_seconds).""" + t0 = time.perf_counter() + # ... mirror production code from llama32_1b_inference.py:434-446 ... + elapsed = time.perf_counter() - t0 + return next_token, elapsed +``` + +- [ ] **Step 2: Commit** + +```bash +git commit -am "ablation-decode: invariant LM head runner" +``` + +## Task 10: Per-token loop wrapper (the timed unit) + +**Files:** +- Create: `decode/cells/per_token_loop.py` + +Wraps a per-layer triple in a 16-layer loop, then runs final RMSNorm + LM head + argmax. **This is the per-trial timed unit.** + +- [ ] **Step 1: Write** + +```python +"""Per-token decode loop wrapper. + +Each call generates ONE decode token at the given current_pos. Cell-specific +dispatch is injected via run_rms_gemv_rope and run_o_gemv_ffn function args. +CPU attention and LM head are invariant. + +Returns: + { + "next_token": int, + "per_layer_npu_wall": list of 16 floats (sum of rms_gemv_rope + o_gemv_ffn per layer), + "cpu_attn_wall": float (sum across 16 layers), + "lm_head_wall": float, + "total_wall": float (everything inside the timer), + } +""" + +import time +import numpy as np +from ml_dtypes import bfloat16 + +from cells.decode_attn_const import run_decode_attention +from cells.lm_head_const import run_lm_head + + +def run_one_decode_token( + cache, config, weights, kv_cache, + x_decode, current_pos, + run_rms_gemv_rope, run_o_gemv_ffn, +): + n_layers = config["n_layers"] + per_layer_npu = [] + cpu_attn_total = 0.0 + x = x_decode + + t_total_start = time.perf_counter() + for L in range(n_layers): + # Per-layer timing + rg_out = run_rms_gemv_rope(cache, layer_inputs={...}, layer_idx=L) + attn_out, attn_t = run_decode_attention( + kv_cache, rg_out["q_roped"], rg_out["k_roped"], rg_out["v"], + layer_idx=L, current_pos=current_pos, config=config, + ) + cpu_attn_total += attn_t + of_out = run_o_gemv_ffn(cache, layer_inputs={...}, layer_idx=L) + x = of_out["output"] + per_layer_npu.append(rg_out["_wall_s"] + of_out["_wall_s"]) + + # Final RMSNorm (CPU) + from llama32_1b_cpu_helpers import rms_norm + x_normed = rms_norm(x.astype(np.float32).reshape(1, config["emb_dim"]), + weights.final_norm.astype(np.float32)).flatten().astype(bfloat16) + next_token, lm_head_t = run_lm_head(cache, x_normed, weights, config["vocab_size"]) + + return { + "next_token": next_token, + "per_layer_npu_wall": per_layer_npu, + "cpu_attn_wall": cpu_attn_total, + "lm_head_wall": lm_head_t, + "total_wall": time.perf_counter() - t_total_start, + } +``` + +- [ ] **Step 2: Smoke test (NPU-free with mock dispatch)** + +Mock `run_rms_gemv_rope` and `run_o_gemv_ffn` to return zeros + dummy wall times. Verify the wrapper completes 16 iterations. + +- [ ] **Step 3: Commit** + +```bash +git commit -am "ablation-decode: per-token loop wrapper (timed unit)" +``` + +--- + +## Phase 4 — Cell D + goldens (Tasks 11–13) + +## Task 11: Cell D — production merged ELFs + +**Files:** +- Create: `decode/cells/cell_d_merged.py` + +Compiles and runs the production `rms_gemv_rope.elf` and `o_gemv_ffn.elf`. Mirror Plan 0's `cell_d_merged.py` and Plan 1's `cell_d_merged.py`. + +- [ ] **Step 1: Write** + +```python +"""Cell D — production-merged decode ELFs. + +Compiles and invokes: +- rms_gemv_rope.elf (6 stitched launches in 1 xrt.run) +- o_gemv_ffn.elf (8 stitched launches in 1 xrt.run) +Same pattern as production llama32_1b_decode.py. +""" + +import time +import numpy as np +from ml_dtypes import bfloat16 + +from kernel_builder.cache import KernelCache +from kernel_builder.backend_presets import RGR_BACKEND, OGF_BACKEND +from multi_launch_builder.rms_gemv_rope_multi import build_rms_gemv_rope_module +from multi_launch_builder.o_gemv_ffn_multi import build_o_gemv_ffn_module + + +def compile_cell_d(cache, config): + if "rms_gemv_rope" not in cache.artifacts: + mod = build_rms_gemv_rope_module(...) + cache.compile_and_cache("rms_gemv_rope", mod, {**RGR_BACKEND, "verbose": cache.verbose}) + if "o_gemv_ffn" not in cache.artifacts: + mod = build_o_gemv_ffn_module(...) + cache.compile_and_cache("o_gemv_ffn", mod, {**OGF_BACKEND, "verbose": cache.verbose}) + cache._save_manifest() + + +def run_rms_gemv_rope_d(cache, layer_inputs, layer_idx): + """Production merged dispatch — mirror llama32_1b_decode.py:run_decode_block.""" + # ... assemble args, call cache.load_and_run("rms_gemv_rope", ...) + # ... return {normed, q, k, v, q_roped, k_roped, _wall_s} + + +def run_o_gemv_ffn_d(cache, layer_inputs, layer_idx): + """Production merged dispatch.""" + # ... call cache.load_and_run("o_gemv_ffn", ...) + # ... return {output, _wall_s} +``` + +- [ ] **Step 2: Quick run on the NPU (preload + 1 trial) to verify it doesn't crash** + +```bash +flock -x -w 1800 /tmp/mlir-air-npu.lock python3 -c " +# Compile + run Cell D once with synthetic inputs +# ... +print('Cell D OK') +" +``` + +- [ ] **Step 3: Commit** + +```bash +git commit -am "ablation-decode: Cell D production-merged decode dispatches" +``` + +## Task 12: Generate goldens + +**Files:** +- Create: `decode/golden/regen_golden.py` +- Create: `decode/golden/golden_rms_gemv_rope_decode.npz` (generated) +- Create: `decode/golden/golden_o_gemv_ffn_decode.npz` (generated) +- Create: `decode/golden/golden_meta.json` (generated) + +- [ ] **Step 1: Write the regen script** + +```python +"""Regenerate the two committed golden fixtures from Cell D. + +Usage: + flock -x -w 1800 /tmp/mlir-air-npu.lock python3 golden/regen_golden.py +""" + +import json +import hashlib +import numpy as np + +# ... synthetic seed=42 inputs (mirror Plan 0/1 golden gen) +# ... run Cell D for layer 0, current_pos=7 +# ... save outputs to npz +# ... write golden_meta.json with hashes, shapes, prompt_len, current_pos +``` + +- [ ] **Step 2: Run on NPU and commit the goldens** + +```bash +flock -x -w 1800 /tmp/mlir-air-npu.lock python3 golden/regen_golden.py +git add golden/golden_rms_gemv_rope_decode.npz \ + golden/golden_o_gemv_ffn_decode.npz \ + golden/golden_meta.json \ + golden/regen_golden.py +git commit -m "ablation-decode: regen + commit Cell D goldens" +``` + +## Task 13: Validation gate test against new goldens + +**Files:** +- Create: `tests/test_validation_gate.py` + +- [ ] **Step 1: Write the test** + +```python +"""Verify Plan 1's validate.py works against the new decode goldens.""" + +import os + +import numpy as np +from validate import validate_against_golden, GoldenMismatch + +GOLDEN_DIR = os.path.join(os.path.dirname(__file__), "..", "golden") + + +def test_validate_passes_on_golden_self(): + """Loading the golden and validating it against itself must pass.""" + npz = np.load(os.path.join(GOLDEN_DIR, "golden_rms_gemv_rope_decode.npz")) + cell_outputs = {key: npz[key] for key in npz.files} + validate_against_golden(cell_outputs, GOLDEN_DIR, + golden_filename="golden_rms_gemv_rope_decode.npz") + + +def test_validate_fails_on_byte_diff(): + npz = np.load(os.path.join(GOLDEN_DIR, "golden_rms_gemv_rope_decode.npz")) + cell_outputs = {key: npz[key].copy() for key in npz.files} + cell_outputs["normed"][0] = 0 # corrupt + try: + validate_against_golden(cell_outputs, GOLDEN_DIR, + golden_filename="golden_rms_gemv_rope_decode.npz") + assert False, "expected GoldenMismatch" + except GoldenMismatch: + pass +``` + +- [ ] **Step 2: Run** + +```bash +python3 -m pytest tests/test_validation_gate.py -v +``` + +Expected: 2 passed. + +- [ ] **Step 3: Commit** + +```bash +git commit -am "ablation-decode: validation gate test" +``` + +--- + +## Phase 5 — Orchestration (Tasks 14–16) + +## Task 14: run_ablation.py orchestrator + +**Files:** +- Create: `decode/run_ablation.py` + +For each cell: validate → 5 trials × {per-token-loop} → emit JSON. Mirror Plan 1's `run_ablation.py`. + +- [ ] **Step 1: Write the orchestrator** + +```python +"""Run the 4-cell full-decode ablation. + +Per cell: +- Compile + preload (not timed) +- 5 trials, each: reset KV cache state → run per_token_loop → record total_wall +- Drop trial 1, median + (min, max) over trials 2-5 + +For each cell, also report per-kernel-group medians (rms_gemv_rope, o_gemv_ffn) +extracted from the per_token_loop's per_layer_npu_wall sums. +""" + +import argparse, json, os, sys, time +import numpy as np + +# ... orchestrator logic, mirror Plan 1's run_ablation.py adapted for per-token-loop +``` + +- [ ] **Step 2: Smoke test JSON output structure (NPU-free)** + +Stub out the actual cell runs to return constant times; verify the JSON has the expected schema. + +- [ ] **Step 3: Commit** + +```bash +git commit -am "ablation-decode: run_ablation.py orchestrator" +``` + +## Task 15: analyze.py report generator + +**Files:** +- Create: `decode/analyze.py` + +JSON → markdown report. Mirror Plan 1's `analyze.py`. + +- [ ] **Step 1: Write** + +Tables to emit: +1. **Per-token total wall** × 4 cells (median + range, Δ vs prev, speedup, vs profile.md decode latency) +2. **Per-kernel-group per-call medians** × 4 cells × {rms_gemv_rope, o_gemv_ffn} +3. **Component breakdown** per cell: NPU wall (rms_gemv_rope + o_gemv_ffn × 16) + CPU attention floor + LM head fixed cost +4. **Findings** stub (filled in manually after first run) + +- [ ] **Step 2: Smoke test on the JSON schema** + +- [ ] **Step 3: Commit** + +```bash +git commit -am "ablation-decode: analyze.py markdown report generator" +``` + +## Task 16: Makefile + README + +**Files:** +- Create: `decode/Makefile` +- Create: `decode/README.md` + +- [ ] **Step 1: Write Makefile** + +```makefile +.PHONY: all compile regen-golden run report clean test + +all: compile run report + +compile: + flock -x -w 1800 /tmp/mlir-air-npu.lock python3 -c "from cells.cell_d_merged import compile_cell_d; from kernel_builder.cache import KernelCache; cache = KernelCache(cache_dir='build', verbose=True); compile_cell_d(cache, CONFIG)" + +regen-golden: + flock -x -w 1800 /tmp/mlir-air-npu.lock python3 golden/regen_golden.py + +run: + flock -x -w 1800 /tmp/mlir-air-npu.lock python3 run_ablation.py --trials 5 --out results.json + +report: + python3 analyze.py results.json > report.md + +test: + python3 -m pytest tests/ -v + +clean: + rm -rf build *.json report.md +``` + +- [ ] **Step 2: Write README** + +Mirror Plan 1's README structure: methodology, headline numbers (TBD until run), reproducibility, file map, limitations. + +- [ ] **Step 3: Commit** + +```bash +git commit -am "ablation-decode: Makefile + README" +``` + +--- + +## Phase 6 — Run + analyze + integrate (Tasks 17–18) + +## Task 17: First end-to-end NPU run + +- [ ] **Step 1: Compile** + +```bash +cd programming_examples/llama32_1b/ablation/decode +flock -x -w 1800 /tmp/mlir-air-npu.lock make compile +``` + +Expected: ~5 min, no errors. + +- [ ] **Step 2: Run** + +```bash +flock -x -w 1800 /tmp/mlir-air-npu.lock make run +cat results.json | python3 -m json.tool | head -40 +``` + +Expected: 4 cells reported with `validation: PASS`, per-token medians in the ms-to-tens-of-ms range, Cell D's per-token median in the ballpark of `profile.md`'s decode latency. + +- [ ] **Step 3: Generate report** + +```bash +make report +cat report.md +``` + +- [ ] **Step 4: Sanity checks** + +- All 4 cells PASS validation? If not, debug before continuing. +- Within-cell range (min/max) is small (<5% of median)? +- A→D speedup is >1× (otherwise something is wrong)? +- Cell D ≈ profile.md decode latency (within ~20%)? + +- [ ] **Step 5: Commit results** + +```bash +git add results.json report.md +git commit -m "ablation-decode: first end-to-end run + report" +``` + +## Task 18: Update ABLATION_STUDY.html with Plan 2 results + +**Files:** +- Modify: `programming_examples/llama32_1b/docs/ABLATION_STUDY.html` + +- [ ] **Step 1: Update Section 5.1 status** + +Change the planned-card from "📋 PLANNED" to "✅ Implemented + measured (date)". + +- [ ] **Step 2: Add Section 5.4 (Results — Plan 2: full decode)** + +Mirror Section 4.3 structure: +- Per-token total wall table (4 cells, median, range, Δ vs prev, speedup, vs profile.md) +- Per-kernel-group per-call medians using the `cmp-table` styling +- Component breakdown (CPU floor, LM head fixed cost, dispatch-affected NPU work) +- Findings ul (3-5 bullet points based on actual numbers) + +- [ ] **Step 3: Update Section 6.1 (cross-comparison)** + +Replace "decode vs. prefill (so far)" with three-way comparison: Plan 0 (single-kernel-group decode) vs Plan 1 (full prefill) vs Plan 2 (full decode). New row in the optimization-effect table for each. + +- [ ] **Step 4: Update Quick recap at bottom** + +Change the Plan 2 entry from "designed only, not yet measured" to "A→D = X.XX×, headline finding ..." + +- [ ] **Step 5: Sidebar nav update if needed (probably no change since 5.1/5.2/5.3 still exist + new 5.4)** + +- [ ] **Step 6: Render-verify in headless Chromium** + +```bash +python3 - <<'EOF' +from playwright.sync_api import sync_playwright +HTML = "/path/to/ABLATION_STUDY.html" +with sync_playwright() as p: + b = p.chromium.launch() + pg = b.new_context().new_page() + pg.goto(f"file://{HTML}") + # Screenshot key sections to verify rendering + ... +EOF +``` + +- [ ] **Step 7: Commit + push** + +```bash +git add programming_examples/llama32_1b/docs/ABLATION_STUDY.html +git commit -m "ABLATION_STUDY: Plan 2 (full decode) results integrated" +``` + +--- + +## Done definition + +- [ ] All 4 cells produce bit-identical outputs against committed goldens (validation PASS) +- [ ] Per-token median for Cell D is within ~20% of `profile.md`'s decode per-token latency +- [ ] Per-kernel-group medians for `rms_gemv_rope` are consistent with Plan 0's pilot (allowing for slight differences from running inside the per-token loop vs. standalone) +- [ ] All NPU-free unit tests pass (`pytest tests/ -v`) +- [ ] `report.md` generated with the 4 cells' numbers + speedup attribution +- [ ] `ABLATION_STUDY.html` updated with Section 5.4 results + Section 6.1 three-way comparison +- [ ] All work on a separate branch / worktree so Plan 0 and Plan 1 directories remain byte-immutable +- [ ] PR-ready: README, Makefile, tests, results.json, report.md all in the new `ablation/decode/` subdir + +--- + +## Estimated effort + +- **Tasks 1-3 (skeleton + re-exports):** 30 min +- **Tasks 4-6 (specs + standalone builders for o_gemv_ffn):** 4-6 hours (the most non-trivial work, especially the K=8192 down GEMV variant) +- **Tasks 7-10 (decode-specific orchestration):** 3-4 hours +- **Tasks 11-13 (Cell D + goldens):** 2-3 hours (includes NPU compile time) +- **Tasks 14-16 (orchestration + report + Makefile):** 2 hours +- **Task 17 (first run + sanity check):** 1 hour (mostly NPU lock + verification) +- **Task 18 (HTML integration):** 1-2 hours + +**Total: ~14-19 hours of focused work + ~1-2 hours of NPU lock time**, comparable to Plan 1's prefill effort. + +If subagent-driven-development is used, expect roughly half a day of controller-time + ~3-5 hours of subagent execution time per task with two-stage review. diff --git a/programming_examples/llama32_1b/ablation/docs/specs/2026-05-07-llama32-1b-ablation-plan2-prefill-design.md b/programming_examples/llama32_1b/ablation/docs/specs/2026-05-07-llama32-1b-ablation-plan2-prefill-design.md new file mode 100644 index 000000000..3bdea113e --- /dev/null +++ b/programming_examples/llama32_1b/ablation/docs/specs/2026-05-07-llama32-1b-ablation-plan2-prefill-design.md @@ -0,0 +1,352 @@ +# Llama-3.2-1B NPU2 Ablation Study — Plan 2 (Prefill) Design + +**Status**: Design (pending implementation plan) +**Date**: 2026-05-07 +**Branch**: implementation on `llama32-1b-ablation-plan2-prefill` (worktree from `llama-3.2-1B-devel`) +**Scope**: `programming_examples/llama32_1b/ablation/prefill/` (new self-contained subdir) +**Companion docs**: +- Master ablation spec: [`2026-05-07-llama32-1b-ablation-study-design.md`](2026-05-07-llama32-1b-ablation-study-design.md) +- Plan 1 (decode `rms_gemv_rope` pilot): [`../plans/2026-05-07-llama32-1b-ablation-decode-pilot.md`](../plans/2026-05-07-llama32-1b-ablation-decode-pilot.md) +- Production profile: [`../../programming_examples/llama32_1b/docs/profile.md`](../../programming_examples/llama32_1b/docs/profile.md) + +--- + +## 1. Goal + +Apply the proven 4-cell ablation methodology (validated by Plan 1 on decode +`rms_gemv_rope`) to the **prefill** pipeline. Two prefill kernel-groups are in +scope: `rms_gemms_rope` (6 sub-launches at seq=2048 GEMM shapes) and `o_ffn` +(8 sub-launches at seq=2048 GEMM shapes). FlashAttention is held constant per +master-spec §5 (un-mergeable per `docs/explain.md`'s `air-opt-shim-dma-bds` +scaling note). + +**Two scopes per cell:** +1. **Single-layer per-call timings** for fast iteration and per-launch + breakdown extraction (matches Plan 1's reporting style). +2. **Full 16-layer prefill wall time** for headline numbers directly + comparable to `profile.md`'s **1.27 s** measured production prefill. + +Plan 2 produces a comprehensive prefill ablation report. Decode completion +(`o_gemv_ffn`) and the LM Head L1/L8 mini-study are explicitly **out of +scope** for this plan — they are scheduled as Plan 2-decode and Plan 2-lm-head +follow-ups. + +## 2. Optimizations under study + +Same three optimizations as Plan 1, applied to the prefill kernel-groups: + +| ID | Optimization | Production behavior in prefill | +|---|---|---| +| **#1** | Multi-launch ELF | Per-layer: 6 sub-launches stitched into `rms_gemms_rope.elf` + 8 sub-launches stitched into `o_ffn.elf`, two `xrt.run()` per layer (plus FA). | +| **#2** | Per-layer weight BOs (`static_input_indices`) | All 16 layers' weights pre-loaded into per-layer BOs once during `prepare_runtime`; `static_input_indices` skips re-write on subsequent calls. | +| **#3** | `intermediate_indices` | Buffers the kernel will overwrite are not host-written on subsequent calls. | + +These are the same flags exercised in Plan 1; what changes is the kernel +shape regime (GEMMs at seq=2048 instead of GEMVs at single-token), the launch +counts (6 + 8 instead of 6), and the multi-layer envelope. + +## 3. Experimental design — the 4-cell ladder + +The ladder applies to the **prefill per-layer triple** (rms_gemms_rope + FA + +o_ffn). FA is invariant across cells; the cells differ only in how they +dispatch the within-kernel-group sub-launches of rms_gemms_rope and o_ffn. + +| Cell | Description | Marginal change | Isolates | +|---|---|---|---| +| **A** Naive no-merge | Each sub-launch as separate `xrt.run()`: 6 calls for rms_gemms_rope + 1 FA + 8 calls for o_ffn = **15 NPU calls per layer**. Host round-trip on every intermediate. Weights re-uploaded every call. | (baseline) | — | +| **B** + per-layer weight BOs | Same as A, but weights pre-loaded into per-layer BOs once; `static_input_indices` skips re-write. Still 15 NPU calls per layer. | +#2 | A→B = #2 alone | +| **C** + shared intermediate BOs | Same as B, but intermediate BOs are aliased across separate `xrt.run()` calls **within each kernel-group** (rms_gemms_rope's 6, and o_ffn's 8). Cross-kernel-group transitions (rms→FA, FA→o_ffn) still go through host — matches production. Still 15 NPU calls per layer. | +#3 (intermediate-BO sharing across separate `xrt.run()` calls within each group) | B→C = #3 alone | +| **D** Multi-launch merged | Production: rms_gemms_rope's 6 sub-launches stitched into one ELF, o_ffn's 8 stitched into one ELF. **3 NPU calls per layer** (rms_gemms_rope + FA + o_ffn). | +#1 | C→D = pure #1 (XRT dispatch saved by group-merging) | + +### Reported claims + +| Reported number | What it answers | +|---|---| +| **A→D** | Total naïve→production speedup for prefill (β baseline) | +| **C→D** | Pure multi-launch merging effect for prefill (α baseline) | +| **A→B** | #2 contribution alone in prefill | +| **B→C** | #3 contribution alone in prefill | +| **A→D × 16 layers vs `profile.md`'s 1.27 s** | Confirms (or corrects) the production headline number from a clean ablation | + +## 4. Invariants across all cells + +To ensure cell-to-cell deltas reflect only the within-kernel-group dispatch +strategy: + +- **Same C++ kernels, shapes, weights, prompt seed.** Bit-exact output + validated against Cell D for layer 0 (one validation gate per kernel-group). +- **FlashAttention is the same standalone ELF in every cell.** + `rms_gemms_rope`'s outputs (`q_roped, k_roped, v`) are extracted to host → + written to FA's BOs → `xrt.run` → `attn_out` extracted to host → written to + o_ffn's residual-add input. This cross-kernel-group host hop happens + identically in all cells. (Cross-group BO sharing is a potential + Plan 2.5 — see §11.) +- **Synthetic deterministic inputs.** numpy seed=42 for layer 0; seed=42+i + for layer i. Same RNG that Plan 1 used. +- **Decode-side optimizations untouched.** Plan 1's decode pilot files at + `programming_examples/llama32_1b/ablation/` top-level remain frozen. +- **NPU power state.** Cells run back-to-back within one process (16-layer + loop keeps NPU active throughout the trial). + +## 5. Correctness verification (load-bearing) + +Mirrors Plan 1 §9, with two adjustments: + +- **Two golden fixtures**, one per kernel-group: + `golden/golden_rms_gemms_rope_prefill.npz` and + `golden/golden_o_ffn_prefill.npz`. Each is Cell D's layer-0 output for that + group (numpy seed=42 inputs). +- **Validation per cell**, before any timing data is collected: + 1. Run cell on layer 0. Compare rms_gemms_rope outputs and o_ffn outputs + bit-exactly against their respective goldens. + 2. **No multi-token decode equivalent** (prefill is single-pass). + 3. CPU reference cosine-sim sanity is logged but not gating (BF16 ≠ F32 by + definition). +- **Cross-cell consistency re-check** after timing: re-run cell A vs D for + layer 0 in the same process; assert byte-equal outputs. Catches BO + recycle / lifetime bugs that surface only after long timing runs. +- Failed cells suppress their timing in the report. + +The validation reuses Plan 1's `programming_examples/llama32_1b/ablation/validate.py` +unchanged (it's kernel-group-agnostic). + +## 6. Per-launch breakdown — falls out of Cell C + +Same mechanism as Plan 1: in Cell C, each sub-launch is its own `xrt.run()` +call → existing `KernelCache.Profiler` records `write_ms / kernel_ms / read_ms` +per call. Cell C automatically yields a 6-line breakdown for rms_gemms_rope +and an 8-line breakdown for o_ffn (in addition to the FA timing, which is +identical across cells). + +D − C therefore quantifies pure dispatch-overhead reduction from merging, +**per kernel-group separately** (so we can report e.g. "merging saves X ms in +rms_gemms_rope and Y ms in o_ffn"). + +## 7. Host overhead — same arithmetic as Plan 1 + +For each cell: + +``` +host_overhead = wall_time − Σ(write_ms + kernel_ms + read_ms) +``` + +Reported per cell. The 16-layer wall-time minus 16 × per-layer NPU sum +reveals Python loop overhead in the multi-layer wrapper, distinct from +per-call host overhead. + +## 8. Implementation approach + +### 8.1 Self-contained subdirectory layout + +All Plan 2 code lives under `programming_examples/llama32_1b/ablation/prefill/`. +Plan 1 files at `ablation/` top level are **byte-immutable** during Plan 2 +development. + +``` +ablation/prefill/ +├── README.md methodology, results, reproducibility +├── Makefile compile / run / report / regen-golden / clean +├── specs/ +│ ├── kernel_group.py dataclass: KernelGroupSpec +│ ├── rms_gemms_rope.py 6-launch spec at prefill shapes +│ └── o_ffn.py 8-launch spec at prefill shapes +├── standalone_builders/ +│ ├── rms_gemms_rope.py 6 single-launch builder wrappers +│ └── o_ffn.py 8 single-launch builder wrappers +├── cells/ +│ ├── cell_a_naive.py parameterized; takes a KernelGroupSpec +│ ├── cell_b_static.py " +│ ├── cell_c_charitable.py " (consumes spec.baton_links) +│ ├── cell_d_merged.py wrapper around production build_*_module +│ ├── flash_attn_const.py FA invocation (held constant) +│ └── multi_layer.py wraps per-layer triple in 16-layer loop +├── golden/ +│ ├── regen_golden.py one-shot Cell-D run, dumps both npz files +│ ├── golden_rms_gemms_rope_prefill.npz +│ └── golden_o_ffn_prefill.npz +├── run_ablation.py orchestrator +├── analyze.py JSON → markdown report +└── tests/ + ├── test_kernel_group_spec.py dataclass validation, NPU-free + ├── test_parameterized_cells.py mock-cache tests, NPU-free + └── test_validation_gate.py re-uses Plan 1's validate.py against new goldens +``` + +### 8.2 KernelGroupSpec dataclass + +A single concrete, grep-friendly description per kernel-group: + +```python +@dataclass(frozen=True) +class SubLaunchSpec: + name: str # e.g. "rmsnorm" | "q_gemm" | "rope_q" + builder_ref: Callable # function returning a 1-launch mlir.Module at production shape + build_kwargs: dict # passed verbatim to builder_ref + weight_slot_in_standalone: int | None # which arg slot of the *standalone* call holds the weight (or None) + output_slot_in_standalone: int # which arg slot of the *standalone* call holds the output + + +@dataclass(frozen=True) +class BatonLink: + producer_idx: int # index into sub_launches list + producer_out_slot: int # output slot of producer's standalone signature + consumer_idx: int # index into sub_launches list (must be > producer_idx) + consumer_in_slot: int # input slot of consumer's standalone signature + + +@dataclass(frozen=True) +class KernelGroupSpec: + name: str # "rms_gemms_rope" | "o_ffn" + sub_launches: list[SubLaunchSpec] # ordered execution sequence + merged_arg_signature: list[str] # ordered names matching production merged ELF args + weight_slots: set[int] # slots in merged signature that are weights/LUTs (for Cell D static_input_indices) + intermediate_slots: set[int] # slots that are kernel-overwritten intermediates (for Cell D intermediate_indices) + output_slots_for_validation: list[int] # slots in merged signature whose bytes go in the golden npz + baton_links: list[BatonLink] # Cell C uses these to alias intermediate BOs across sub-launches +``` + +Walking this spec gives each cell its dispatch sequence + BO-management +parameters. Adding a new kernel-group later (e.g., `o_gemv_ffn` for Plan +2-decode) = one new spec file; cell logic is unchanged. + +### 8.3 Standalone (1-launch) ELFs + +Same approach as Plan 1: thin wrappers around existing sub-builders in +`multi_launch_builder/rms_gemms_rope_multi.py` and +`multi_launch_builder/o_ffn_multi.py`, called with single-launch stitch +specs at production prefill shapes (seq=2048). + +The wrappers should match the same `_extract_public_func_name` pattern Plan +1 settled on for `instance_name` — the standalone ELF's exported symbol +must be the actual MLIR public func name, not the cache key. + +### 8.4 Cell-specific harness (parameterized) + +| Cell | Implementation | +|---|---| +| **A** | Walks `spec.sub_launches` in order, invokes each via `cache.load_and_run(naive=True)` (Plan 1's `KernelCache.naive=True` mode). Per the §3 cross-group note: between rms_gemms_rope and FA, and FA and o_ffn, intermediates flow through host (extract → write to next group's input arrays). | +| **B** | Same as A, but a `preload(spec, weights_per_layer)` pass writes weights into per-layer BOs first (per-layer `bo_key`). Subsequent calls use `static_input_indices=spec.weight_slots`. | +| **C** | Same as B, but after preload, walk `spec.baton_links` and call `_share_bo` (Plan 1's helper, lifted into `prefill/cells/common.py` if needed) to alias intermediate BOs across sub-launches within each group. Use `intermediate_indices` for both producer-output and consumer-input slots. | +| **D** | Wrapper around production `build_rms_gemms_rope_module(seq_len=2048, ...)` and `build_o_ffn_module(seq_len=2048, ...)`. Two `cache.load_and_run` calls per layer (one per merged ELF). Unpacks output by slot index per Plan 1's lesson. | +| **flash_attn_const** | Compiles FA via existing `flash_attention/kernel_fusion_based/attn_npu2_seqfirst.py:build_module` with the same kwargs production uses. Invocation is identical in every cell — same `bo_key`, same `output_indices`, same FA-input/output extraction pattern. | +| **multi_layer** | Wraps a per-layer triple in a 16-layer loop. Threads `x_in[layer_i+1] = o_ffn_output[layer_i]`. Used by both single-layer and 16-layer orchestrator scopes. | + +### 8.5 Validation + +Reuses Plan 1's `programming_examples/llama32_1b/ablation/validate.py` +verbatim (read-only import). Two golden npz files + per-cell validation gate ++ cross-cell consistency re-check (per §5). Failed cells suppress timing. + +### 8.6 Orchestrator scopes + +``` +run_ablation.py supports two timing scopes: + --scope=single-layer 5 trials × 1-layer cell call + --scope=16-layer 5 trials × 16-layer cell call + --scope=both (default) both above; report both numbers +``` + +Both scopes run the same validation gate (layer-0 against golden) before +timing. + +## 9. Statistical methodology + +- **5 trials per cell × scope**, drop trial 1 (warmup), report median + (min, max). +- All `xrt.run()` invocations wrapped in `flock -x -w 1800 + /tmp/mlir-air-npu.lock` per `CLAUDE.md`. +- 16-layer trials may exhibit higher variance than single-layer (more + opportunity for NPU jitter). Budget for 10 trials on 16-layer scope if + median ± range > 5 %. + +## 10. Deliverable: `programming_examples/llama32_1b/ablation/prefill/` + +Self-contained mini-project with its own `make all` entry point: + +``` +make compile # one-time, ~10-15 min (16 ELFs at seq=2048 + FA) +make regen-golden # one-shot, after Cell D changes +make run # all 4 cells × both scopes, emit JSON +make report # markdown report +make all # compile + run + report +make clean # wipe build/ +``` + +The auto-generated report includes: +- Validation badge table (per cell PASS/FAIL). +- Single-layer per-call timing table (per cell × per kernel-group). +- 16-layer total wall-time table (per cell, with comparison to `profile.md`'s 1.27 s). +- Marginal delta tables (per kernel-group AND aggregated). +- Per-launch breakdown extracted from Cell C (6 lines for rms_gemms_rope, 8 lines for o_ffn). +- Host-overhead share per cell. +- Comparison against `profile.md`'s "Key Optimizations" table claims. + +A pointer is added to `programming_examples/llama32_1b/ablation/README.md` +(Plan 1's README) cross-linking to this study. + +## 11. Out of scope (explicitly) + +- **Plan 2-decode**: Decode `o_gemv_ffn` ablation (4 cells × 8 sub-launches). Same methodology; deferred to next sub-plan. +- **Plan 2-lm-head**: LM Head L1 (production 8-merged) vs L8 (8 separate `xrt.run()`) mini-study. Orthogonal homogeneous-merging characterization. +- **Plan 2.5 (potential)**: Cross-kernel-group BO sharing (rms_gemms_rope's `q_roped/k_roped/v` outputs aliased to FA's input BOs; FA's `attn_out` aliased to o_ffn's residual-add input). Production doesn't do this; could be a separate optimization study. +- **Tier A #4 / #5** from the master spec (last-token LM Head; CPU vs NPU LM Head GEMV). +- **All Tier B** (seq-first FA/RoPE; FA vs naive attention; CPU vs NPU decode attention; `omit_pingpong` toggling; LM Head partition sweep beyond {1, 8}). +- **Real HuggingFace weights.** Synthetic seed=42 only. + +## 12. Isolation strategy + +### 12.1 Worktree + +``` +git worktree add ../mlir-air-ablation-plan2 -b llama32-1b-ablation-plan2-prefill +``` + +The user's primary checkout at `/home/jiajli/apps/mlir-air/` (currently on +`llama-3.2-1B-devel`) is not perturbed. Plan 2 work happens in +`../mlir-air-ablation-plan2/` on its own branch. The user can review Plan 1 +files in the primary checkout while Plan 2 develops. + +### 12.2 File-level guarantee + +Plan 2 code only **imports** from Plan 1's read-only modules +(`programming_examples/llama32_1b/ablation/cells/common.py:compile_standalone_kernels`, +`ablation/validate.py`, `ablation/cells/baton.py:_share_bo` may be lifted into +prefill/cells/common.py if needed but the original is not modified). + +Production code (`programming_examples/llama32_1b/kernel_builder/cache.py`) +already has the `naive=True` mode from Plan 1; Plan 2 introduces no further +changes to it. + +### 12.3 Merge plan + +After Plan 2 is implemented and tested, the worktree branch is merged into +`llama-3.2-1B-devel` (or a parent branch as the user designates). Because +Plan 2 only adds files and never modifies existing ones, the merge is +fast-forward / no-conflict. + +## 13. Risks + +| Risk | Mitigation | +|---|---| +| 14 standalone ELFs at seq=2048 + FA = ~16 ELFs to compile, ~10–15 min one-time | Cached to disk after first compile; documented in README. | +| 16 layers × multiple weight tensors at seq=2048 ≈ 1 GB resident BO memory | Verified to fit on test machine; if not, fall back to 1-layer scope only. | +| Parameterized cell logic harder to debug than Plan 1's hardcoded form | KernelGroupSpec is a frozen dataclass; cells walk it mechanically. Unit tests on a mock cache verify each cell's call sequence per spec. | +| FA ELF first-time compile is ~46 s per `profile.md` | Compiled once, cached. Verified once via FA's own validation. | +| Cell A high BO traffic on 16-layer scope may dominate variance | Bump trial count to 10 for Cell A 16-layer if 5-run median ± range > 5 %. | +| Cross-cell consistency re-check (§5) may fail after long 16-layer runs if BO recycle has bugs | If failure occurs, suspend cell and report — don't trust timing. | + +## 14. Success criteria + +The study succeeds if it produces: + +1. A reproducible harness (single `make all` from + `programming_examples/llama32_1b/ablation/prefill/`). +2. Every reported cell passes the §5 correctness gate (per-cell + cross-cell + bit-exact). +3. Numerical attribution for #1, #2, #3 in the prefill regime, per + kernel-group AND aggregated. +4. Per-launch breakdown for both prefill kernel-groups (from Cell C). +5. Host-overhead share for each cell (single-layer and 16-layer scopes). +6. 16-layer total prefill wall-time numbers with confirmed (or corrected) + comparison to `profile.md`'s 1.27 s headline. +7. Plan 1 files unmodified (`git diff main..plan2-branch` shows only file + additions in `ablation/prefill/`). diff --git a/programming_examples/llama32_1b/ablation/docs/specs/2026-05-12-llama32-1b-ablation-plan2-fulldecode-design.md b/programming_examples/llama32_1b/ablation/docs/specs/2026-05-12-llama32-1b-ablation-plan2-fulldecode-design.md new file mode 100644 index 000000000..24516d485 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/docs/specs/2026-05-12-llama32-1b-ablation-plan2-fulldecode-design.md @@ -0,0 +1,233 @@ +# Llama-3.2-1B NPU2 Ablation Study — Plan 2 (Full Decode) Design + +**Status**: Design (pending implementation plan) +**Date**: 2026-05-12 +**Branch**: implementation on a fresh worktree from `llama-3.2-1B-devel` +**Scope**: `programming_examples/llama32_1b/ablation/decode/` (new self-contained subdir) + +**Companion docs:** +- Master ablation spec: [`2026-05-07-llama32-1b-ablation-study-design.md`](2026-05-07-llama32-1b-ablation-study-design.md) +- Plan 0 (decode `rms_gemv_rope` pilot) plan: [`../plans/2026-05-07-llama32-1b-ablation-decode-pilot.md`](../plans/2026-05-07-llama32-1b-ablation-decode-pilot.md) +- Plan 1 (full prefill) spec: [`2026-05-07-llama32-1b-ablation-plan2-prefill-design.md`](2026-05-07-llama32-1b-ablation-plan2-prefill-design.md) +- Plan 1 (full prefill) plan: [`../plans/2026-05-07-llama32-1b-ablation-plan2-prefill.md`](../plans/2026-05-07-llama32-1b-ablation-plan2-prefill.md) +- ABLATION_STUDY.html Part 5 (Plan 2 design summary, audience-facing): `programming_examples/llama32_1b/docs/ABLATION_STUDY.html#plan2-status` +- Production profile: `programming_examples/llama32_1b/docs/profile.md` + +--- + +## 1. Goal + +Apply the proven 4-cell ablation methodology — validated by Plan 0 (decode `rms_gemv_rope` pilot, A→D = 2.75×) and Plan 1 (full prefill, A→D = 1.56×, Cell D = 1.13 s ≈ profile.md's 1.27 s) — to the **full decode** dispatch pipeline. Three decode kernel-groups are in scope: + +- `rms_gemv_rope` (6 sub-launches at single-token GEMV shapes) — already pilot-tested in Plan 0 +- `o_gemv_ffn` (8 sub-launches at single-token GEMV shapes) — new in this plan +- `lm_head_gemv` (8 partitions stitched in 1 ELF, 1 NPU call/token) — held INVARIANT across cells (rationale §4) + +The CPU-side `decode_attention_cpu` is also held invariant (it's CPU code; nothing to ablate). FlashAttention's NPU decode path is OUT OF SCOPE — production decode uses CPU attention at head_dim=64 because the NPU FA path has overhead at single-query workloads. + +**Two scopes per cell:** +1. **Per-kernel-group single-call timings** for each of `rms_gemv_rope` and `o_gemv_ffn` — fast iteration and per-launch breakdown extraction (matches Plan 0/1's reporting style). +2. **Per-token full-pipeline wall time** = 16 layers × (rms_gemv_rope + decode_attn_cpu + o_gemv_ffn) + final RMSNorm + lm_head_gemv + argmax. Headline number directly comparable to `profile.md`'s per-token decode latency. + +Plan 2 produces the comprehensive end-to-end decode ablation report. After Plan 2, all three production phases (single-kernel-group decode, end-to-end prefill, end-to-end decode) have controlled measurements. + +## 2. Optimizations under study + +Same three optimizations as Plan 0/1, applied to the decode kernel-groups: + +| ID | Optimization | Production behavior in decode | +|----|--------------|-------------------------------| +| **#1** | Multi-launch ELF | Per layer per token: 6 sub-launches stitched into `rms_gemv_rope.elf`, 8 stitched into `o_gemv_ffn.elf`. Two `xrt.run()` per layer (plus the CPU attention step). Per token: 16 × 2 + 1 (LM head) = **33 NPU calls**. | +| **#2** | Per-layer weight BOs (`static_input_indices`) | All 16 layers' decode weights pre-loaded into per-layer BOs once during `prepare_runtime`; `static_input_indices` skips re-write on subsequent calls. Same `bo_key=f"name_L{layer_idx}"` trick as production. | +| **#3** | `intermediate_indices` | Buffers the kernel will overwrite are not host-written on subsequent calls. For Cell C, intermediate BOs are also explicitly aliased across separate `xrt.run()` calls within each kernel-group via `_share_bo` (mirror Plan 0/1). | + +These are the same three flags exercised in Plan 0/1; what changes is the dispatch envelope (per-token loop instead of single dispatch or 16-layer prefill loop) and the addition of `o_gemv_ffn` as a second cell-ablated kernel-group. + +## 3. Experimental design — the 4-cell ladder + +The ladder applies to the **decode per-layer triple** (rms_gemv_rope + decode_attn_cpu + o_gemv_ffn). The CPU attention is invariant across cells. Cells differ only in how they dispatch the within-kernel-group sub-launches of `rms_gemv_rope` and `o_gemv_ffn`. LM head is invariant (production-merged in every cell). + +| Cell | Description | Marginal change | Isolates | +|------|-------------|-----------------|----------| +| **A** Naive no-merge | Each sub-launch as separate `xrt.run()`: 6 calls for `rms_gemv_rope` + 1 CPU attn + 8 calls for `o_gemv_ffn` = **14 NPU calls per layer**. Plus 8 calls for `lm_head_gemv` per token (held merged here per §4 rationale; if also un-merged, would be 22). Per token: 14 × 16 + 8 = **232 NPU calls (with LM head merged) / 232 + 7 = 239 (with LM head un-merged)**. Production-decode-uses-merged baseline: **232 calls/token in Cell A**. Host round-trip on every intermediate. Weights re-uploaded every call. | (baseline) | — | +| **B** + per-layer weight BOs | Same as A, but weights pre-loaded into per-layer BOs once; `static_input_indices` skips re-write. Same NPU call count. | +#2 | A→B = #2 alone | +| **C** + shared intermediate BOs | Same as B, but intermediate BOs are aliased across separate `xrt.run()` calls **within each kernel-group**. Cross-kernel-group transitions (rms_gemv_rope → CPU attn → o_gemv_ffn) still go through host — matches production. Same NPU call count. | +#3 (intermediate-BO sharing across separate `xrt.run()` calls within each group) | B→C = #3 alone | +| **D** Multi-launch merged | Production: `rms_gemv_rope`'s 6 stitched into one ELF, `o_gemv_ffn`'s 8 stitched into one ELF. **2 NPU calls per layer + 1 LM head per token = 33 NPU calls/token** (matches profile.md). | +#1 | C→D = pure #1 (XRT dispatch saved by group-merging) | + +### Reported claims + +| Reported number | What it answers | +|-----------------|------------------| +| **A→D (per-token wall)** | Total naïve→production speedup for decode | +| **C→D** | Pure multi-launch merging effect for decode | +| **A→B** | #2 contribution alone in decode | +| **B→C** | #3 contribution alone in decode | +| **Per-kernel-group medians** | Per-call wall time for each of `rms_gemv_rope` and `o_gemv_ffn` across cells (analogous to Plan 1's per-call breakdown table) | +| **Cell D per-token wall vs `profile.md`** | Confirms (or corrects) the production decode per-token number from a clean ablation | +| **Cross-comparison vs Plan 0** | Does the single-kernel-group finding (Plan 0: #2 dominates at 1.60×) hold at full per-token end-to-end scale, or shift when `o_gemv_ffn` is added to the ablation envelope? | + +## 4. Invariants across all cells + +To ensure cell-to-cell deltas reflect only the within-kernel-group dispatch strategy: + +- **Same C++ kernels, shapes, weights, prompt seed.** Bit-exact output validated against Cell D for layer 0 (one validation gate per kernel-group: `rms_gemv_rope` and `o_gemv_ffn`). +- **`decode_attention_cpu` is the same Python/numpy function in every cell.** Its CPU work is ~constant across cells (same input shapes, same KV cache state at the timed window's start) — see §6 for state management. +- **`lm_head_gemv.elf` is held INVARIANT (production-merged) in every cell.** Rationale: it's structurally one `xrt.run()` with 8 stitched launches; production already merges; it is invariant in the same sense `flash_attn.elf` is invariant in Plan 1. Reporting it as a separate "fixed cost per token" line keeps the 4 cells comparable on the parts that DO change. If a follow-up Plan 2.5 wants to ablate LM head dispatch separately (option (b) or (c) from the ABLATION_STUDY.html design), it can be done on top of Plan 2's results. +- **Same KV cache initial state at the start of every cell's timed window.** A fixed-seed pre-fill of `prompt_len = 7` populates layer 0..15 cache slots 0..6; `current_pos = 7` at trial start. Each trial generates exactly ONE decode token. After the trial, the cache slot at position 7 is NOT preserved across trials — re-initialized per trial so each trial measures the same starting state. +- **NPU exclusive-locked**: `flock -x -w 1800 /tmp/mlir-air-npu.lock` mirrors Plan 0/1. +- **Synthetic deterministic inputs** from numpy `seed=42` (mirrors Plan 0/1 exactly). + +## 5. Timing protocol + +**Per cell:** +1. **Preload** (not timed): build cache state, pre-load weights into per-layer BOs (Cells B/C/D), allocate intermediate BOs (Cell C aliasing wired here). +2. **5 timed trials**, each generating exactly **1 decode token** starting from the same KV cache state (`current_pos = 7`). +3. **Drop trial 1 as warmup** (XRT context warmup, instruction-cache fill, BO-mapping cache fill). +4. **Report median + (min, max) over trials 2–5** per cell. + +**Why single-token per trial (not 32-token loops):** +- Per-token decode wall time has a position-dependent component: `decode_attention_cpu` reads `[0:current_pos+1]` of the KV cache, so its CPU work scales linearly with `current_pos`. Generating 32 tokens means each token's wall time grows slightly with position, contaminating the dispatch-only comparison we care about. +- Single-token-at-fixed-position keeps the CPU attention work CONSTANT across trials and across cells. +- Trade-off: 5 trials × 1 token gives less smoothing than 5 trials × 32 tokens. Mitigation: warmup-trial-drop captures the first-call overhead; trials 2-5 should be very tight (similar to Plan 0/1's within-cell variance of <1% of mean). + +## 6. KV cache state management + +Each cell sees identical cache state at the start of each timed trial: + +``` +At trial start: + k_cache[0..15, :, 0:7, :] = synthetic-pre-filled (seed=42) + v_cache[0..15, :, 0:7, :] = synthetic-pre-filled (seed=42) + k_cache[0..15, :, 7:, :] = zeros + v_cache[0..15, :, 7:, :] = zeros + current_pos = 7 + +During trial: + For L in 0..15: + rms_gemv_rope (NPU) # produces q_roped, k_roped, v + decode_attention_cpu (CPU) # reads k/v_cache[L, :, 0:8, :]; writes k/v at slot 7 + o_gemv_ffn (NPU) # produces next-layer x_decode + final_rmsnorm (CPU, single row) + lm_head_gemv (NPU) + argmax (CPU) + +At trial end: + Reset k_cache[L, :, 7, :] = 0 and v_cache[L, :, 7, :] = 0 for all L. + (Or more simply: reset entire cache from the saved pre-filled state.) +``` + +The cache reset between trials is a host-side numpy array assignment — negligible cost outside the timed window. + +## 7. Validation gate + +Mirror Plan 0/1: every cell must produce **byte-identical** outputs for both `rms_gemv_rope` and `o_gemv_ffn` against committed Cell D goldens, on the seed=42 synthetic input at `current_pos = 7`. Cells failing the gate have their timing suppressed in the report. + +Two committed `golden_*.npz` fixtures (one per kernel-group), regenerated by Cell D's harness if production kernels change. The validation step compares all six rms_gemv_rope outputs (`normed, q, k, v, q_roped, k_roped`) and the eight o_gemv_ffn outputs (intermediate buffers + final layer output). For LM head: validate that the final argmax token id matches across all four cells (single-integer comparison; bit-exact). + +## 8. File structure (proposed) + +All paths under `programming_examples/llama32_1b/ablation/decode/` (new sibling of `ablation/prefill/`). + +| File | Responsibility | Mirrors | +|------|----------------|---------| +| `__init__.py` | Package marker | — | +| `README.md` | Methodology, run instructions, results, reproducibility | Plan 1's `README.md` | +| `Makefile` | `make compile / regen-golden / run / report / all / clean` | Plan 1's `Makefile` | +| `specs/__init__.py` | Package marker | — | +| `specs/kernel_group.py` | `SubLaunchSpec`, `BatonLink`, `KernelGroupSpec` (or re-export from `ablation/prefill/specs/kernel_group.py` to share definitions) | Plan 1 | +| `specs/rms_gemv_rope.py` | `KernelGroupSpec` instance for the 6-launch decode attention pre-block | Plan 1's `specs/rms_gemms_rope.py` | +| `specs/o_gemv_ffn.py` | `KernelGroupSpec` instance for the 8-launch decode FFN block | Plan 1's `specs/o_ffn.py` | +| `standalone_builders/__init__.py` | Package marker | — | +| `standalone_builders/rms_gemv_rope.py` | Re-export Plan 0's existing `STANDALONES` registry (already in `ablation/standalone_builders/decode_rms_gemv_rope.py`) | Plan 0 | +| `standalone_builders/o_gemv_ffn.py` | 8 single-launch builder wrappers + `STANDALONES` registry — NEW | Plan 1's `standalone_builders/o_ffn.py` | +| `cells/__init__.py` | Package marker | — | +| `cells/common.py` | `compile_standalone_kernels` (parameterized), `_share_bo`, `_extract_public_func_name`, helpers — re-export or copy from Plan 1 | Plan 1's `cells/common.py` | +| `cells/cell_a_naive.py` | Parameterized Cell A — walks a `KernelGroupSpec` with `naive=True` | Plan 1 | +| `cells/cell_b_static.py` | Parameterized Cell B — preload weights, `static_input_indices` | Plan 1 | +| `cells/cell_c_charitable.py` | Parameterized Cell C — preload + alias intermediate BOs per `spec.baton_links` | Plan 1 | +| `cells/cell_d_merged.py` | Wraps production `build_rms_gemv_rope_module` and `build_o_gemv_ffn_module` from `multi_launch_builder/` | Plan 1 | +| `cells/decode_attn_const.py` | CPU attention invariant: same Python function in every cell | Plan 1's `flash_attn_const.py` | +| `cells/lm_head_const.py` | LM head invariant: production-merged 8-partition GEMV in every cell | NEW (Plan 1's FA invariant pattern) | +| `cells/per_token_loop.py` | Wraps a per-layer triple in a 16-layer loop, then runs final RMSNorm + LM head + argmax. **The end-to-end timed unit.** | Plan 1's `cells/multi_layer.py` | +| `golden/__init__.py` | Package marker | — | +| `golden/regen_golden.py` | One-shot Cell-D run for layer 0; dumps two npz fixtures + meta json | Plan 1 | +| `golden/golden_rms_gemv_rope_decode.npz` | Committed bit-exact reference (Cell D, layer 0, seed=42, current_pos=7) | Plan 1 | +| `golden/golden_o_gemv_ffn_decode.npz` | Committed bit-exact reference for o_gemv_ffn | Plan 1 | +| `golden/golden_meta.json` | Hashes, shapes, config, prompt_len, current_pos | Plan 1 | +| `run_ablation.py` | Orchestrator: validate → time × {per-call, per-token} × 4 cells, emit JSON | Plan 1 | +| `analyze.py` | JSON → markdown report | Plan 1 | +| `tests/__init__.py` | Package marker | — | +| `tests/conftest.py` | Pytest sys.path setup | Plan 1 | +| `tests/test_kernel_group_spec.py` | Dataclass invariants (NPU-free) | Plan 1 (or just import from Plan 1's tests) | +| `tests/test_parameterized_cells.py` | Mock-cache tests verifying each cell walks its spec correctly (NPU-free) | Plan 1 | +| `tests/test_validation_gate.py` | Tests against the two new decode goldens | Plan 1 | +| `tests/test_kv_cache_state.py` | NEW: verifies cache initialization + per-trial reset is deterministic | NEW | + +**Files NOT touched** (Plan 0/1 isolation guarantee): every file under `programming_examples/llama32_1b/ablation/` outside `decode/`. Production code (`programming_examples/llama32_1b/{kernel_builder,multi_launch_builder}/`) read-only — only imported. + +## 9. Open design decisions (RESOLVED) + +For traceability, the 7 questions raised in `ABLATION_STUDY.html#plan2-validation` and their answers (per user discussion 2026-05-12): + +| # | Question | Decision | +|---|----------|----------| +| 1 | How many tokens to generate per timed run? | **1 decode token per trial × 5 trials, drop trial 1 (warmup), median over trials 2-5.** Avoids position-dependent CPU attention growth contaminating the dispatch comparison. | +| 2 | Should LM head be its own cell ladder? | **Hold INVARIANT** (production-merged in every cell). Mirrors Plan 1's FA treatment. Defer separate LM-head ablation to a possible Plan 2.5. | +| 3 | KV cache state initialization | **Deterministic synthetic pre-fill of 7 tokens** from `seed=42`; reset between trials. | +| 4 | Where does `decode_attention_cpu` wall time get attributed? | **Counted in per-token total AND reported separately as a "CPU floor" line** (mirrors Plan 1's FA reporting). | +| 5 | Predicted findings | **Not in the spec or plan.** Forecasts become bias when running. Report only after measurement. | +| 6 | Production CPU-attention or experimental NPU FA decode? | **Production CPU-attention path only.** That's what `profile.md` reflects. | +| 7 | Where does the harness live? | **`programming_examples/llama32_1b/ablation/decode/`** (new sibling of `ablation/prefill/`). | + +## 10. Out of scope + +- **NPU FlashAttention decode path** (head_dim=64). Production uses CPU; this plan doesn't ablate the alternative. +- **LM Head L1/L8 mini-study** (whether to use 1-launch or 8-partition LM head). Held invariant in this plan; can be a follow-up Plan 2.5. +- **Cross-kernel-group BO aliasing** (rms_gemv_rope output BO → CPU attention input → o_gemv_ffn input). This is the C2 future-work entry in IMPLEMENTATION_GUIDE.html. Cross-group goes through host in every cell, matching production. +- **Tokens beyond a single fixed `current_pos`.** Single-token-at-fixed-position is intentional (§5). +- **Real HuggingFace weights.** Synthetic seed=42 only — same justification as Plan 0/1. +- **Numerical-precision study vs an HF / F32 reference.** That belongs to the production verify subsystem (`make verify` for the top-k token gate, `make diagnosis` for per-layer cosine), not duplicated here. + +## 11. Risk register + +| Risk | Likelihood | Impact | Mitigation | +|------|------------|--------|------------| +| Single-token timing has high variance because no per-token smoothing | Medium | Medium | Warmup-drop + 5 trials usually suffices (Plan 0 saw <1% within-cell variance with the same approach). If trials 2-5 spread is >5% of median, increase to 9 trials (drop 1). | +| `o_gemv_ffn` standalone builder for cell A/B/C is more complex than `rms_gemv_rope`'s (8 sub-launches incl. SwiGLU + Down GEMV at K=8192) | High | Medium | Carefully reuse Plan 1's `standalone_builders/o_ffn.py` patterns; the kernel-group structure parallels but with GEMV instead of GEMM and the special `mv_k8192.o` for the Down step. Allow extra time for this task. | +| Bit-exact validation across 32 generated tokens (if we extend later) might fail because cache state evolves identically only if every cell sees the same input bytes at every position | Low (since we use 1 token) | Low | Single-token approach sidesteps this entirely. If we later extend to multi-token, validation must hash all generated outputs, not just the first. | +| LM head's per-token wall time is non-trivial (~14 ms typical), so even though it's invariant it shifts the per-token total significantly | Low | Low | Report the LM head as a separate fixed-cost line (mirrors Plan 1's FA reporting). Doesn't bias cell-to-cell deltas. | +| Goldens become stale when production kernels are recompiled (e.g., after a Peano upgrade) | Medium | Medium | Same as Plan 0/1: `make regen-golden` documented; validation gate fails loudly so divergence is visible. | +| KV cache state between trials accidentally drifts (e.g., partial reset bug) | Low | High (would invalidate timing if cells see different input data) | `tests/test_kv_cache_state.py` verifies reset determinism BEFORE timing trials run. | + +## 12. Reproducibility guarantee + +``` +git clone && git checkout +cd programming_examples/llama32_1b/ablation/decode +make clean +make all # compile (~5 min) + run (~2 min, NPU-locked) + report +``` + +Expected output (5 trials per cell, drop trial 1, median + range): +``` + Cell A: PASS per-token median=~XX ms range=[~YY, ~ZZ]ms + Cell B: PASS per-token median=~XX ms range=[~YY, ~ZZ]ms + Cell C: PASS per-token median=~XX ms range=[~YY, ~ZZ]ms + Cell D: PASS per-token median=~XX ms range=[~YY, ~ZZ]ms +``` + +(Numbers TBD by implementation. Cell D per-token median should be in the ballpark of `profile.md`'s decode latency, modulo ~1-2 ms of host steps not in the timed window.) + +NPU-free unit tests: `python3 -m pytest tests/ -v` should report 8+ passed. + +## 13. Companion ABLATION_STUDY.html updates (post-implementation) + +After Plan 2 is implemented and measured, update `programming_examples/llama32_1b/docs/ABLATION_STUDY.html`: + +- Section 5.1 (status): change from "📋 PLANNED" to "✅ Implemented + measured" +- Add new Section 5.4 (Results — Plan 2: full decode), parallel to Sections 3.3 and 4.3 +- Update Section 6.1 (cross-comparison): replace "decode vs. prefill (so far)" with three-way comparison (Plan 0 vs Plan 1 vs Plan 2) +- Update Quick recap at bottom +- Update sidebar nav if needed + +These updates are part of the Plan 2 implementation plan, not a separate plan. diff --git a/programming_examples/llama32_1b/ablation/prefill/.gitignore b/programming_examples/llama32_1b/ablation/prefill/.gitignore new file mode 100644 index 000000000..f0c28021f --- /dev/null +++ b/programming_examples/llama32_1b/ablation/prefill/.gitignore @@ -0,0 +1,16 @@ +# Build / kernel cache artifacts +build/ +standalone_cache/ +air_project/ +__pycache__/ +*.pyc + +# Compiled NPU kernel objects (generated by Peano during make compile) +*.o +*.elf +*.mlir +*.insts.bin + +# Run artifacts (regenerated each `make run`) +results_*.json +report_*.md diff --git a/programming_examples/llama32_1b/ablation/prefill/Makefile b/programming_examples/llama32_1b/ablation/prefill/Makefile new file mode 100644 index 000000000..0fb5429cc --- /dev/null +++ b/programming_examples/llama32_1b/ablation/prefill/Makefile @@ -0,0 +1,34 @@ +# Llama-3.2-1B prefill ablation harness +# +# make compile — compile all standalone ELFs + Cell D's 2 merged ELFs + FA (~10-15 min, cached) +# make regen-golden — regenerate committed golden fixtures (rare; only after Cell D changes) +# make run — run all 4 cells × 2 kernel-groups × both scopes, emit JSON +# make report — generate markdown report from latest results JSON +# make all — compile + run + report +# make clean — wipe build/ + +srcdir := $(shell dirname $(realpath $(firstword $(MAKEFILE_LIST)))) +BUILD := build + +.PHONY: help compile regen-golden run report all clean + +help: + @echo "make compile | regen-golden | run | report | all | clean" + +compile: + @mkdir -p $(BUILD) + cd $(BUILD) && PYTHONPATH=$(srcdir):$(srcdir)/..:$(srcdir)/../..:$(srcdir)/../../..:$$PYTHONPATH flock -x -w 1800 /tmp/mlir-air-npu.lock python3 -m cells.common + +regen-golden: compile + cd $(BUILD) && PYTHONPATH=$(srcdir):$(srcdir)/..:$(srcdir)/../..:$(srcdir)/../../..:$$PYTHONPATH flock -x -w 1800 /tmp/mlir-air-npu.lock python3 $(srcdir)/golden/regen_golden.py + +run: compile + cd $(BUILD) && PYTHONPATH=$(srcdir):$(srcdir)/..:$(srcdir)/../..:$(srcdir)/../../..:$$PYTHONPATH flock -x -w 1800 /tmp/mlir-air-npu.lock python3 $(srcdir)/run_ablation.py --out results_latest.json + +report: + cd $(BUILD) && python3 $(srcdir)/analyze.py results_latest.json --out report_latest.md && cat report_latest.md + +all: compile run report + +clean: + rm -rf $(BUILD) diff --git a/programming_examples/llama32_1b/ablation/prefill/README.md b/programming_examples/llama32_1b/ablation/prefill/README.md new file mode 100644 index 000000000..5a0261185 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/prefill/README.md @@ -0,0 +1,116 @@ +# Llama-3.2-1B Prefill Ablation (Plan 2) + +Bit-exact 4-cell ablation of the production **prefill** pipeline: +`rms_gemms_rope` (6 launches) + FlashAttention (held constant) + `o_ffn` +(8 launches), at seq=2048 GEMM shapes, both single-layer and full 16-layer +scopes. + +Companion docs: +- Plan 2 spec: [`../docs/specs/2026-05-07-llama32-1b-ablation-plan2-prefill-design.md`](../docs/specs/2026-05-07-llama32-1b-ablation-plan2-prefill-design.md) +- Plan 1 (decode pilot): [`../README.md`](../README.md) +- Production profile: [`../../docs/profile.md`](../../docs/profile.md) + +## What this measures + +Four cells, identical computation, different dispatch strategy: + +| Cell | What changes within each kernel-group | Adds | +|------|---------------------------------------|------| +| A | 6 + 8 separate `xrt.run()` per layer, host round-trip on every intermediate | (baseline) | +| B | + per-layer weight BOs (`static_input_indices`) | #2 | +| C | + shared intermediate BOs across separate `xrt.run()` calls | #3 | +| D | + multi-launch merging (production: 6→1 + 8→1 ELF per layer) | #1 | + +FA is held constant per spec (un-mergeable). Cross-kernel-group transfers +(rms→FA, FA→o_ffn) go through host in every cell — matches production. + +## Pilot measurements (final smoke run) + +### 16-layer total wall — comparable to profile.md's 1.27 s + +| Cell | Median (s) | Range | Δ vs prev | Speedup | vs profile.md | +|---|---|---|---|---|---| +| A — Naive | 1.754 | [1.751, 1.755] | — | (baseline) | 1.38× slower | +| B — + per-layer weight BOs (#2) | 1.589 | [1.584, 1.594] | A→B = +0.165 s | **1.10×** | 1.25× slower | +| C — + shared intermediate BOs (#3) | 1.212 | [1.212, 1.222] | B→C = +0.377 s | **1.31×** | 0.95× faster | +| D — + multi-launch merging (#1) | 1.125 | [1.124, 1.127] | C→D = +0.087 s | **1.08×** | 0.89× faster | +| | | | **A→D total** | **1.56×** | | + +5 trials per cell, drop trial 1 (warmup), median + (min, max) over remaining 4. +**Cell D = 1.125 s ≈ profile.md's 1.27 s** (small overshoot from embedding lookup, KV cache extraction, etc. not in this harness). + +### Single-layer per-call medians (ms) + +| Cell | rms_gemms_rope | o_ffn | +|---|---|---| +| A | 14.99 | 75.05 | +| B | 12.52 | 64.67 | +| C | 9.77 | 45.01 | +| D | 7.43 | 40.99 | + +Per-call speedups: rms_gemms_rope A→D = 2.02×, o_ffn A→D = 1.83×. + +### Findings + +- **#3 (shared intermediate BOs) dominates in prefill** at 1.31× — *opposite of decode* where #3 ≈ 1.0×. In prefill, per-launch intermediates are large (e.g. 8 MB GEMM outputs at seq=2048) and the bandwidth saved by aliasing BOs is significant. +- **#2 (per-layer weight BOs) is small in prefill** (1.10×) — weights are big but the per-call NPU compute is much bigger, so weight-transfer cost is a smaller fraction of total time. (Decode is the opposite: weights dominate because per-call compute is small.) +- **Pure multi-launch merging (#1) is small in prefill** (1.08×) — same intuition: dispatch overhead matters less when per-call work is large. +- **Total A→D = 1.56× speedup** for prefill — smaller than decode's 2.75× because per-call work is much bigger, so dispatch-related overheads are a smaller share. +- **All 4 cells produce bit-identical output bytes** (validated against committed golden fixtures from Cell D), so timing differences are purely dispatch-related. + +## Quick start + +``` +make compile # one-time, ~10-15 min for 14 standalone ELFs + 2 merged + FA +make run # 5 trials × both scopes × all 4 cells (~5-10 min) +make report # markdown report +``` + +## Validation gate + +Every cell's per-kernel-group output must match the committed `golden/*.npz` +fixtures bit-exactly (synthetic numpy seed=42 inputs). Cells failing the +gate suppress their timing in the report. + +## Reproducibility + +``` +cd programming_examples/llama32_1b/ablation/prefill +make clean && make all +``` + +The 16-layer Cell D total wall time should be in the ballpark of +`profile.md`'s **1.27 s** production headline. The marginal deltas table +attributes how much each of optimizations #1, #2, #3 contributes to that +number for prefill specifically. + +Unit tests (NPU-free): + +``` +python3 -m pytest tests/ -v +``` + +Expected: 8 passed (4 KernelGroupSpec + 4 validation gate). + +## Limitations of this plan (Plan 2-decode and Plan 2-lm-head will address) + +- Prefill only — decode `o_gemv_ffn` and the LM Head L1/L8 mini-study are separate plans. +- FA is invariant in every cell. A potential **Plan 2.5** could ablate cross-kernel-group BO sharing (FA's input BOs aliased to rms_gemms_rope's output BOs). +- Synthetic weights only. No HuggingFace. + +## File map + +| Path | Purpose | +|------|---------| +| `specs/kernel_group.py` | Frozen dataclasses (SubLaunchSpec, BatonLink, KernelGroupSpec) | +| `specs/{rms_gemms_rope,o_ffn}.py` | Concrete spec instances | +| `standalone_builders/` | Re-exported STANDALONES registries | +| `cells/cell_{a,b,c,d}_*.py` | Parameterized cell harnesses | +| `cells/flash_attn_const.py` | FA invariant | +| `cells/multi_layer.py` | 16-layer wrapper | +| `cells/common.py` | Compile harness, BO baton-pass helper, public-func-name extractor | +| `golden/` | Two committed npz fixtures + regen script + meta json | +| `validate.py` | Parameterized bit-exact gate | +| `run_ablation.py` | Orchestrator | +| `analyze.py` | Report generator | +| `Makefile` | Convenience targets | diff --git a/programming_examples/llama32_1b/ablation/prefill/__init__.py b/programming_examples/llama32_1b/ablation/prefill/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/programming_examples/llama32_1b/ablation/prefill/analyze.py b/programming_examples/llama32_1b/ablation/prefill/analyze.py new file mode 100644 index 000000000..c9513a7e4 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/prefill/analyze.py @@ -0,0 +1,112 @@ +"""Read prefill results JSON and emit a markdown report. + +Sections: +- Validation badge (per cell × kernel-group) +- Single-layer per-call medians (per cell × kernel-group) +- 16-layer total wall (per cell, with comparison to profile.md's 1.27s) +- Marginal deltas (A→B, B→C, C→D, A→D — per kernel-group AND aggregated) +- Per-launch breakdown extracted from Cell C's single-layer timing data +""" + +import argparse +import json +import os +import time + +PROFILE_MD_HEADLINE_S = 1.27 # production prefill from profile.md + + +def report(results): + cells = results["cells"] + out = [] + out.append("# Prefill Ablation — Report\n") + out.append( + f"Trials: {results['trials']}, config: seq={results['config']['seq_len']}, " + f"emb={results['config']['emb_dim']}, hidden={results['config']['hidden_dim']}\n" + ) + + # Validation table + out.append("## Validation\n") + out.append("| Cell | rms_gemms_rope | o_ffn |") + out.append("|------|----------------|-------|") + for c in ("A", "B", "C", "D"): + rg = cells.get(c, {}).get("rms_gemms_rope", {}).get("validation", "—") + of = cells.get(c, {}).get("o_ffn", {}).get("validation", "—") + out.append(f"| {c} | {rg} | {of} |") + out.append("") + + # Single-layer per-call timing table + out.append("## Single-layer per-call medians (ms)\n") + out.append("| Cell | rms_gemms_rope | o_ffn |") + out.append("|------|----------------|-------|") + for c in ("A", "B", "C", "D"): + rg_s = ( + cells.get(c, {}) + .get("rms_gemms_rope", {}) + .get("single_layer", {}) + .get("median_s") + ) + of_s = cells.get(c, {}).get("o_ffn", {}).get("single_layer", {}).get("median_s") + rg_str = f"{rg_s*1000:.2f}" if rg_s is not None else "—" + of_str = f"{of_s*1000:.2f}" if of_s is not None else "—" + out.append(f"| {c} | {rg_str} | {of_str} |") + out.append("") + + # 16-layer headline table + out.append("## 16-layer total wall (s) — comparable to profile.md's 1.27 s\n") + out.append("| Cell | Median (s) | Min (s) | Max (s) | vs profile.md |") + out.append("|------|------------|---------|---------|---------------|") + for c in ("A", "B", "C", "D"): + e = cells.get(c, {}).get("16_layer", {}) + if not e: + out.append(f"| {c} | — | — | — | — |") + continue + md = e["median_s"] + mn = e["min_s"] + mx = e["max_s"] + ratio = md / PROFILE_MD_HEADLINE_S + out.append(f"| {c} | {md:.3f} | {mn:.3f} | {mx:.3f} | {ratio:.2f}× |") + out.append("") + + # Marginal deltas (16-layer total) + out.append("## Marginal deltas (16-layer total)\n") + + def m(c): + return cells.get(c, {}).get("16_layer", {}).get("median_s") + + pairs = [ + ("A→B (= #2 per-layer weight BOs)", "A", "B"), + ("B→C (= #3 shared intermediate BOs)", "B", "C"), + ("C→D (= #1 multi-launch merging, isolated)", "C", "D"), + ("A→D (= total dispatch-related speedup)", "A", "D"), + ] + out.append("| Comparison | Δ s | Speedup |") + out.append("|------------|-----|---------|") + for label, a, b in pairs: + ma, mb = m(a), m(b) + if ma is None or mb is None: + out.append(f"| {label} | — | — |") + continue + out.append(f"| {label} | {ma - mb:+.3f} | {ma/mb:.2f}× |") + out.append("") + + return "\n".join(out) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("results_json") + ap.add_argument("--out", default=None) + args = ap.parse_args() + with open(args.results_json) as f: + results = json.load(f) + text = report(results) + out = args.out or f"report_prefill_{int(time.time())}.md" + with open(out, "w") as f: + f.write(text) + print(f"Wrote {out}\n") + print(text) + + +if __name__ == "__main__": + main() diff --git a/programming_examples/llama32_1b/ablation/prefill/cells/__init__.py b/programming_examples/llama32_1b/ablation/prefill/cells/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/programming_examples/llama32_1b/ablation/prefill/cells/cell_a_naive.py b/programming_examples/llama32_1b/ablation/prefill/cells/cell_a_naive.py new file mode 100644 index 000000000..cc5fd19ed --- /dev/null +++ b/programming_examples/llama32_1b/ablation/prefill/cells/cell_a_naive.py @@ -0,0 +1,227 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + +"""Cell A -- Naive no-merge for a generic KernelGroupSpec. + +Walks spec.sub_launches in order. For each sub-launch: + 1. Build the 3-element args list per the spec's slot semantics. + 2. Invoke cache.load_and_run with naive=True (writes everything, + reads everything every call). + 3. Store output in results dict keyed by sub.name. + +Cross-sub-launch data flows via the host (extracted to numpy in a results +dict, then passed to the next call as input). + +naive=True forces load_and_run to: + - set output_indices = list(range(len(inputs))) (read back all slots) + - skip static_input_indices and intermediate_indices optimizations + +The returned result[slot] is always a 1D flat numpy array. Baton-link values +are passed directly as inputs to downstream sub-launches; the BO write uses +raw bytes so 1D vs 2D shape does not matter as long as byte counts match. +""" + +import time + +import numpy as np +from ml_dtypes import bfloat16 + +from cells.common import compile_standalone_kernels + + +def _output_shape_for(spec_name, sub_name, config): + """Return numpy shape of the output buffer for (spec_name, sub_name). + + The output buffer is allocated as zeros with this shape and passed at + sub.output_slot_in_standalone. The kernel writes into it; load_and_run + returns a 1D flat view (byte-compatible with the 2D shape). + """ + seq = config["seq_len"] + emb = config["emb_dim"] + kv = config["kv_dim"] + hid = config["hidden_dim"] + n_total = seq * emb + + if spec_name == "rms_gemms_rope": + return { + "rmsnorm": (seq, emb), + "q_gemm": (seq, emb), + "k_gemm": (seq, kv), + "v_gemm": (seq, kv), + "rope_q": (seq, emb), + "rope_k": (seq, kv), + }[sub_name] + + if spec_name == "o_ffn": + return { + "o_gemm": (seq, emb), + "res_add": (seq, emb), + "ffn_rmsnorm": (seq, emb), + "gate_gemm": (seq, hid), + "up_gemm": (seq, hid), + "swiglu": (seq, hid), + "down_gemm": (seq, emb), + "ffn_add": (n_total,), # 1D output (standalone emits 1D; see o_ffn.py) + }[sub_name] + + raise ValueError(f"unknown spec {spec_name!r}") + + +def _static_input_for(spec_name, sub_name, slot, layer_inputs): + """Return the static (weight/LUT/layer-level) array for this slot, or None. + + Returns None when the slot should come from a baton link (upstream + sub-launch output) or from the output buffer. + """ + if spec_name == "rms_gemms_rope": + # Slot conventions (from rms_gemms_rope.py docstring): + # rmsnorm: (x_in[slot0], norm_w[slot1], out[slot2]) + # gemm: (A[slot0], B_weight[slot1], C[slot2]) + # rope_2d: (in[slot0], lut[slot1], out[slot2]) + if sub_name == "rmsnorm": + if slot == 0: + return layer_inputs["x_in"] + if slot == 1: + return layer_inputs["norm_w"] + elif sub_name == "q_gemm": + if slot == 1: + return layer_inputs["wq"] + # slot 0 comes from rmsnorm baton + elif sub_name == "k_gemm": + if slot == 1: + return layer_inputs["wk"] + # slot 0 comes from rmsnorm baton + elif sub_name == "v_gemm": + if slot == 1: + return layer_inputs["wv"] + # slot 0 comes from rmsnorm baton + elif sub_name == "rope_q": + if slot == 1: + return layer_inputs["lut_q"] + # slot 0 comes from q_gemm baton + elif sub_name == "rope_k": + if slot == 1: + return layer_inputs["lut_k"] + # slot 0 comes from k_gemm baton + return None + + if spec_name == "o_ffn": + # Slot conventions (from o_ffn.py docstring): + # gemm: (A[slot0], B_weight[slot1], C[slot2]) + # add_2d_to_2d: (A[slot0], B[slot1], C[slot2]) no weight + # rmsnorm: (x[slot0], w[slot1], out[slot2]) + # swiglu_2d: (gate[slot0], up[slot1], out[slot2]) no weight + # ffn_add: (A[slot0], B[slot1], out[slot2]) no weight + if sub_name == "o_gemm": + if slot == 0: + return layer_inputs["attn_out"] + if slot == 1: + return layer_inputs["wo"] + elif sub_name == "res_add": + # slot0 = proj (from o_gemm baton); slot1 = x_residual (static) + if slot == 1: + return layer_inputs["x_residual"] + elif sub_name == "ffn_rmsnorm": + if slot == 1: + return layer_inputs["ffn_norm_w"] + # slot 0 comes from res_add baton + elif sub_name == "gate_gemm": + if slot == 1: + return layer_inputs["w_gate"] + # slot 0 comes from ffn_rmsnorm baton + elif sub_name == "up_gemm": + if slot == 1: + return layer_inputs["w_up"] + # slot 0 comes from ffn_rmsnorm baton + elif sub_name == "swiglu": + # both slot0 (gate) and slot1 (up) come from batons + pass + elif sub_name == "down_gemm": + if slot == 1: + return layer_inputs["w_down"] + # slot 0 comes from swiglu baton + elif sub_name == "ffn_add": + # slot0 = down (from down_gemm baton); slot1 = res1 (from res_add baton) + pass + return None + + raise ValueError(f"unknown spec {spec_name!r}") + + +def compile_cell_a(cache, spec, backend_preset): + """Compile the standalone ELFs for this kernel-group into cache.""" + registry = [(s.name, s.builder_ref, s.build_kwargs) for s in spec.sub_launches] + compile_standalone_kernels(cache, spec.name, registry, backend_preset) + + +def run_cell_a(cache, spec, layer_inputs, config, backend_preset, layer_idx=0): + """Run all spec.sub_launches sequentially with naive=True. + + Each sub-launch is a separate xrt.run() call. All host<->device transfers + are done unconditionally (naive=True means no skipping of static or + intermediate buffers). + + Args: + cache: KernelCache with manifested artifacts. + spec: KernelGroupSpec (rms_gemms_rope or o_ffn). + layer_inputs: dict of numpy arrays keyed by semantic name + (e.g. "x_in", "norm_w", "wq", "attn_out", etc.). + config: dict with seq_len, emb_dim, kv_dim, hidden_dim. + backend_preset: backend kwargs dict (instance_name will be removed). + layer_idx: layer index (unused in Cell A, present for API consistency). + + Returns: + dict keyed by sub.name -> 1D flat numpy array of that sub-launch's + output, plus "_wall_s" for total wall time. + """ + # Strip instance_name; compile_cell_a sets it per-kernel. + backend = {**backend_preset} + backend.pop("instance_name", None) + + results = {} + t0 = time.perf_counter() + + for idx, sub in enumerate(spec.sub_launches): + out_shape = _output_shape_for(spec.name, sub.name, config) + out_buf = np.zeros(out_shape, dtype=bfloat16) + + # Build the 3-arg list (all standalones have exactly 3 args). + args = [None, None, None] + + for slot in range(3): + if slot == sub.output_slot_in_standalone: + args[slot] = out_buf + continue + + # Try static (weight/layer-level) lookup first. + v = _static_input_for(spec.name, sub.name, slot, layer_inputs) + if v is not None: + args[slot] = v + continue + + # Otherwise this slot is fed by an upstream baton link. + for link in spec.baton_links: + if link.consumer_idx == idx and link.consumer_in_slot == slot: + producer_name = spec.sub_launches[link.producer_idx].name + args[slot] = results[producer_name] + break + + assert args[slot] is not None, ( + f"[cell_a] no source found for {spec.name}/{sub.name} slot={slot}. " + f"Check baton_links and _static_input_for." + ) + + kernel_name = f"{spec.name}__{sub.name}" + result = cache.load_and_run( + kernel_name, + backend, + *args, + naive=True, + ) + # naive=True sets output_indices = list(range(3)), so result is a 3-tuple. + # The output is at sub.output_slot_in_standalone. + results[sub.name] = result[sub.output_slot_in_standalone] + + elapsed = time.perf_counter() - t0 + results["_wall_s"] = elapsed + return results diff --git a/programming_examples/llama32_1b/ablation/prefill/cells/cell_b_static.py b/programming_examples/llama32_1b/ablation/prefill/cells/cell_b_static.py new file mode 100644 index 000000000..517bdebae --- /dev/null +++ b/programming_examples/llama32_1b/ablation/prefill/cells/cell_b_static.py @@ -0,0 +1,244 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + +"""Cell B -- Cell A + per-layer weight BOs + static_input_indices. + +Same dataflow as Cell A (walks spec.sub_launches, threads via baton links), +but weights are pre-loaded once into per-layer BOs during preload phase. +The timed run phase skips the weight host->device sync via static_input_indices. + +Two public phases: + + preload_cell_b(cache, spec, weights_per_layer, config, backend_preset) + Called once before timing. For each (layer_idx, sub_launch): + - Builds a 3-arg list with the actual weight at weight_slot_in_standalone + and dummy zeros at all other slots. + - Calls load_and_run with output_indices=[output_slot], + static_input_indices={weight_slot}, and + bo_key=f"B_{spec.name}_{sub.name}_L{layer_idx}". + Sub-launches with weight_slot_in_standalone=None are skipped (no weight + to preload; those sub-launches just use default bo_key in the timed run). + + run_cell_b(cache, spec, layer_inputs, config, backend_preset, layer_idx=0) + Same loop as Cell A but: + - No naive=True. + - Passes static_input_indices={sub.weight_slot_in_standalone} (or empty + set if None) and output_indices=[sub.output_slot_in_standalone]. + - Passes bo_key=f"B_{spec.name}_{sub.name}_L{layer_idx}" -- must + byte-match the preload bo_key. + +Helpers _output_shape_for and _static_input_for are imported from cell_a_naive +to avoid duplication. +""" + +import time + +import numpy as np +from ml_dtypes import bfloat16 + +from cells.cell_a_naive import _output_shape_for, _static_input_for +from cells.common import compile_standalone_kernels + + +def _activation_shape_for(spec_name, sub_name, config): + """Return the numpy shape of the activation (non-weight, non-output) input slot. + + This is needed during preload to allocate a correctly-sized dummy BO for the + activation slot. All current standalones have exactly 3 args: + (activation, weight, output). The activation is always at slot 0. + + Shapes must match what _static_input_for / baton links would supply at + run time, because the BO is allocated on the first call (preload) and + reused on subsequent calls (run). A size mismatch raises a ValueError + inside KernelCache.load_and_run when it tries to copy src into the BO. + """ + seq = config["seq_len"] + emb = config["emb_dim"] + kv = config["kv_dim"] + hid = config["hidden_dim"] + + if spec_name == "rms_gemms_rope": + # All sub-launches: activation at slot 0 is either x_in (seq,emb) or + # the normed/q/k output fed via baton -- all are (seq, emb) or (seq, kv). + return { + # rmsnorm: x_in is (seq, emb) + "rmsnorm": (seq, emb), + # gemms: A input is (seq, emb) -- the normed activation + "q_gemm": (seq, emb), + "k_gemm": (seq, emb), + "v_gemm": (seq, emb), + # ropes: activation slot is the q/k output + "rope_q": (seq, emb), + "rope_k": (seq, kv), + }[sub_name] + + if spec_name == "o_ffn": + return { + # o_gemm: activation = attn_out (seq, emb) + "o_gemm": (seq, emb), + # ffn_rmsnorm: activation = res1 (seq, emb) + "ffn_rmsnorm": (seq, emb), + # gate/up gemms: activation = normed2 (seq, emb) + "gate_gemm": (seq, emb), + "up_gemm": (seq, emb), + # down_gemm: activation = swiglu (seq, hid) + "down_gemm": (seq, hid), + }[sub_name] + + raise ValueError(f"unknown spec {spec_name!r} or sub {sub_name!r}") + + +def compile_cell_b(cache, spec, backend_preset): + """Compile the standalone ELFs for this kernel-group into cache.""" + registry = [(s.name, s.builder_ref, s.build_kwargs) for s in spec.sub_launches] + compile_standalone_kernels(cache, spec.name, registry, backend_preset) + + +def preload_cell_b(cache, spec, weights_per_layer, config, backend_preset): + """Pre-load per-layer weights into dedicated BOs. + + For each (layer_idx, weights) pair and each sub-launch with a weight slot, + run a one-shot load_and_run that writes the weight into the BO. Subsequent + timed runs reuse the same BO (identified by bo_key) and skip the write. + + Args: + cache: KernelCache with manifested artifacts. + spec: KernelGroupSpec (rms_gemms_rope or o_ffn). + weights_per_layer: list of dicts (one per layer), each keyed by semantic + weight name (same keys accepted by _static_input_for / Cell A). + config: dict with seq_len, emb_dim, kv_dim, hidden_dim. + backend_preset: backend kwargs dict (instance_name will be removed). + """ + backend = {**backend_preset} + backend.pop("instance_name", None) + + for layer_idx, layer_weights in enumerate(weights_per_layer): + for sub in spec.sub_launches: + if sub.weight_slot_in_standalone is None: + # No weight slot -- nothing to preload for this sub-launch. + continue + + out_shape = _output_shape_for(spec.name, sub.name, config) + out_buf = np.zeros(out_shape, dtype=bfloat16) + + # Build the 3-arg list: weight at weight_slot, output at output_slot, + # dummy zeros at remaining slot(s). + args = [None, None, None] + weight_slot = sub.weight_slot_in_standalone + output_slot = sub.output_slot_in_standalone + args[output_slot] = out_buf + + # Retrieve the weight array using the same lookup as Cell A. + weight_arr = _static_input_for( + spec.name, sub.name, weight_slot, layer_weights + ) + assert weight_arr is not None, ( + f"[cell_b preload] _static_input_for returned None for " + f"{spec.name}/{sub.name} slot={weight_slot}. " + f"Check weight keys in weights_per_layer." + ) + args[weight_slot] = weight_arr + + # Fill any remaining slot with a correctly-sized dummy zero array. + # The BO is allocated on this first call and reused in run_cell_b; + # the size must match what the real activation will supply. + for slot in range(3): + if args[slot] is None: + act_shape = _activation_shape_for(spec.name, sub.name, config) + args[slot] = np.zeros(act_shape, dtype=bfloat16) + + bo_key = f"B_{spec.name}_{sub.name}_L{layer_idx}" + kernel_name = f"{spec.name}__{sub.name}" + + cache.load_and_run( + kernel_name, + backend, + *args, + output_indices=[output_slot], + static_input_indices={weight_slot}, + bo_key=bo_key, + ) + + +def run_cell_b(cache, spec, layer_inputs, config, backend_preset, layer_idx=0): + """Run all spec.sub_launches sequentially with pre-loaded weight BOs. + + Same dataflow as Cell A (batons via results dict) but: + - Uses static_input_indices={weight_slot} to skip weight write on this call. + - Uses output_indices=[output_slot] instead of naive read-all. + - Uses bo_key matching the preload phase so the same BO set is reused. + + Sub-launches with weight_slot_in_standalone=None (e.g. swiglu, ffn_add) + have no static weight -- they use an empty static_input_indices set and + the same bo_key pattern for BO identity. + + Args: + cache: KernelCache with manifested artifacts. + spec: KernelGroupSpec (rms_gemms_rope or o_ffn). + layer_inputs: dict of numpy arrays keyed by semantic name. + config: dict with seq_len, emb_dim, kv_dim, hidden_dim. + backend_preset: backend kwargs dict (instance_name will be removed). + layer_idx: layer index used to select the right pre-loaded BO set. + + Returns: + dict keyed by sub.name -> 1D flat numpy array of that sub-launch's + output, plus "_wall_s" for total wall time. + """ + backend = {**backend_preset} + backend.pop("instance_name", None) + + results = {} + t0 = time.perf_counter() + + for idx, sub in enumerate(spec.sub_launches): + out_shape = _output_shape_for(spec.name, sub.name, config) + out_buf = np.zeros(out_shape, dtype=bfloat16) + + # Build the 3-arg list (all standalones have exactly 3 args). + args = [None, None, None] + + for slot in range(3): + if slot == sub.output_slot_in_standalone: + args[slot] = out_buf + continue + + # Try static (weight/layer-level) lookup first. + v = _static_input_for(spec.name, sub.name, slot, layer_inputs) + if v is not None: + args[slot] = v + continue + + # Otherwise this slot is fed by an upstream baton link. + for link in spec.baton_links: + if link.consumer_idx == idx and link.consumer_in_slot == slot: + producer_name = spec.sub_launches[link.producer_idx].name + args[slot] = results[producer_name] + break + + assert args[slot] is not None, ( + f"[cell_b] no source found for {spec.name}/{sub.name} slot={slot}. " + f"Check baton_links and _static_input_for." + ) + + # Determine static_input_indices for this sub-launch. + if sub.weight_slot_in_standalone is not None: + static_indices = {sub.weight_slot_in_standalone} + else: + static_indices = set() + + kernel_name = f"{spec.name}__{sub.name}" + bo_key = f"B_{spec.name}_{sub.name}_L{layer_idx}" + + result = cache.load_and_run( + kernel_name, + backend, + *args, + output_indices=[sub.output_slot_in_standalone], + static_input_indices=static_indices, + bo_key=bo_key, + ) + results[sub.name] = result[sub.output_slot_in_standalone] + + elapsed = time.perf_counter() - t0 + results["_wall_s"] = elapsed + return results diff --git a/programming_examples/llama32_1b/ablation/prefill/cells/cell_c_charitable.py b/programming_examples/llama32_1b/ablation/prefill/cells/cell_c_charitable.py new file mode 100644 index 000000000..555066541 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/prefill/cells/cell_c_charitable.py @@ -0,0 +1,279 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + +"""Cell C -- Cell B + shared intermediate BOs across separate xrt.run() calls, +parameterized over a KernelGroupSpec. Walks spec.baton_links to alias BOs. + +Two public phases: + + preload_cell_c(cache, spec, weights_per_layer, config, backend_preset) + Called once before timing. For each (layer_idx, layer_weights) pair: + 1. Run each sub-launch once (allocates BOs and writes weights via + static_input_indices). Uses bo_key=f"C_{spec.name}_{sub.name}_L{li}". + 2. Walk spec.baton_links and alias each producer's output BO into + the consumer's input BO slot via _share_bo. + + run_cell_c(cache, spec, layer_inputs, config, backend_preset, layer_idx=0) + Same dataflow as Cell B but with: + - bo_key=f"C_{spec.name}_{sub.name}_L{layer_idx}" (matches preload). + - intermediate_indices: producer output slots and consumer input slots + that are baton-managed (host skips writing those BOs). + +For a baton-aliased slot, a np.zeros placeholder is passed to load_and_run; +the bytes are NOT written to device because the slot is in intermediate_indices. +""" + +import time + +import numpy as np +from ml_dtypes import bfloat16 + +from cells.cell_a_naive import _output_shape_for, _static_input_for +from cells.common import compile_standalone_kernels, _share_bo + +# --------------------------------------------------------------------------- +# Compile (same registry walk as Cell A / Cell B) +# --------------------------------------------------------------------------- + + +def compile_cell_c(cache, spec, backend_preset): + """Compile the standalone ELFs for this kernel-group into cache.""" + registry = [(s.name, s.builder_ref, s.build_kwargs) for s in spec.sub_launches] + compile_standalone_kernels(cache, spec.name, registry, backend_preset) + + +# --------------------------------------------------------------------------- +# Shape helpers +# --------------------------------------------------------------------------- + + +def _slot_shape_for(spec_name, sub_name, slot, config): + """Return the numpy shape for an arbitrary (sub_name, slot) pair. + + Covers both weight slots and activation/baton slots so that the preload + loop can allocate correctly-sized BOs for all sub-launches, including + those with no weight slot (res_add, swiglu, ffn_add). + + For weight slots this returns the weight shape (2-D for GEMMs, 1-D for + norms/LUTs). For activation/baton slots it returns the activation shape. + """ + seq = config["seq_len"] + emb = config["emb_dim"] + kv = config["kv_dim"] + hid = config["hidden_dim"] + + if spec_name == "rms_gemms_rope": + # slot 2 = output for every sub-launch; handled by _output_shape_for. + table = { + # slot0 slot1 + "rmsnorm": [(seq, emb), (emb,)], + "q_gemm": [(seq, emb), (emb, emb)], + "k_gemm": [(seq, emb), (emb, kv)], + "v_gemm": [(seq, emb), (emb, kv)], + "rope_q": [(seq, emb), (seq * emb,)], + "rope_k": [(seq, kv), (seq * kv,)], + } + return table[sub_name][slot] + + if spec_name == "o_ffn": + table = { + # slot0 slot1 + "o_gemm": [(seq, emb), (emb, emb)], + "res_add": [(seq, emb), (seq, emb)], + "ffn_rmsnorm": [(seq, emb), (emb,)], + "gate_gemm": [(seq, emb), (emb, hid)], + "up_gemm": [(seq, emb), (emb, hid)], + "swiglu": [(seq, hid), (seq, hid)], + "down_gemm": [(seq, hid), (hid, emb)], + "ffn_add": [(seq, emb), (seq, emb)], + } + return table[sub_name][slot] + + raise ValueError(f"unknown spec {spec_name!r} or sub {sub_name!r}") + + +# --------------------------------------------------------------------------- +# Baton-link helpers +# --------------------------------------------------------------------------- + + +def _intermediate_slots_for_sub(spec, sub_idx): + """For a given sub-launch index, return the set of slots that are + baton-managed (either produced or consumed via a baton link). + + These slots are passed as intermediate_indices to load_and_run so the + host skips writing them: + - Producer output slot: the kernel writes here; downstream reads from the + same BO via the alias. + - Consumer input slot: upstream already wrote to it via the shared BO; + host must not overwrite with zeros. + """ + slots = set() + for link in spec.baton_links: + if link.producer_idx == sub_idx: + slots.add(link.producer_out_slot) + if link.consumer_idx == sub_idx: + slots.add(link.consumer_in_slot) + return slots + + +# --------------------------------------------------------------------------- +# Preload phase +# --------------------------------------------------------------------------- + + +def preload_cell_c(cache, spec, weights_per_layer, config, backend_preset): + """One-shot allocation: run each sub-launch once to materialise BOs, then + alias intermediate BOs across sub-launches per spec.baton_links. + + Phase 1 (inner loop over sub_launches): Each sub-launch is invoked once + with its actual weight in place and dummy zeros for all other inputs. + This causes KernelCache to allocate the BO set for that bo_key. + + Phase 2 (inner loop over baton_links): _share_bo aliases the producer's + output BO into the consumer's input BO slot so that both operations refer + to the same xrt.bo object. + """ + backend = {**backend_preset} + backend.pop("instance_name", None) + + for li, layer_weights in enumerate(weights_per_layer): + # --- Phase 1: allocate BOs for every sub-launch --- + for sub in spec.sub_launches: + out_shape = _output_shape_for(spec.name, sub.name, config) + args = [None, None, None] + + for slot in range(3): + if slot == sub.output_slot_in_standalone: + args[slot] = np.zeros(out_shape, dtype=bfloat16) + continue + if ( + sub.weight_slot_in_standalone is not None + and slot == sub.weight_slot_in_standalone + ): + # Use the actual weight so the BO is populated from the start. + w = _static_input_for(spec.name, sub.name, slot, layer_weights) + assert w is not None, ( + f"[cell_c preload] _static_input_for returned None for " + f"{spec.name}/{sub.name} slot={slot}" + ) + args[slot] = w + continue + # Activation or baton-fed slot: correctly-sized dummy zeros. + args[slot] = np.zeros( + _slot_shape_for(spec.name, sub.name, slot, config), dtype=bfloat16 + ) + + static_idx = ( + {sub.weight_slot_in_standalone} + if sub.weight_slot_in_standalone is not None + else set() + ) + kernel_name = f"{spec.name}__{sub.name}" + bo_key = f"C_{spec.name}_{sub.name}_L{li}" + + cache.load_and_run( + kernel_name, + backend, + *args, + output_indices=[sub.output_slot_in_standalone], + static_input_indices=static_idx, + bo_key=bo_key, + ) + + # --- Phase 2: alias BOs per baton_links --- + for link in spec.baton_links: + producer = spec.sub_launches[link.producer_idx] + consumer = spec.sub_launches[link.consumer_idx] + _share_bo( + cache, + f"C_{spec.name}_{producer.name}_L{li}", + link.producer_out_slot, + f"C_{spec.name}_{consumer.name}_L{li}", + link.consumer_in_slot, + ) + + +# --------------------------------------------------------------------------- +# Timed run phase +# --------------------------------------------------------------------------- + + +def run_cell_c(cache, spec, layer_inputs, config, backend_preset, layer_idx=0): + """Run all spec.sub_launches sequentially with pre-loaded weight BOs and + shared intermediate BOs (baton-pass). + + Differences from Cell B: + - bo_key uses "C_" prefix (matches preload). + - intermediate_indices is set for each sub-launch based on baton_links: + * producer's output slot -> kernel overwrites it; don't host-write + * consumer's input slot -> aliased to upstream BO; don't host-write + + For baton-fed input slots the numpy arg is np.zeros (placeholder); bytes + are skipped because the slot is in intermediate_indices. + + Args: + cache: KernelCache with manifested artifacts (preload must have run). + spec: KernelGroupSpec (rms_gemms_rope or o_ffn). + layer_inputs: dict of numpy arrays keyed by semantic name. + config: dict with seq_len, emb_dim, kv_dim, hidden_dim. + backend_preset: backend kwargs dict (instance_name will be removed). + layer_idx: layer index used to select the right pre-loaded BO set. + + Returns: + dict keyed by sub.name -> 1D flat numpy array of that sub-launch's + output, plus "_wall_s" for total wall time. + """ + backend = {**backend_preset} + backend.pop("instance_name", None) + + results = {} + t0 = time.perf_counter() + + for idx, sub in enumerate(spec.sub_launches): + out_shape = _output_shape_for(spec.name, sub.name, config) + + # Build the 3-arg list. + args = [None, None, None] + + for slot in range(3): + if slot == sub.output_slot_in_standalone: + args[slot] = np.zeros(out_shape, dtype=bfloat16) + continue + + # Try static (weight/LUT/layer-level) lookup first. + v = _static_input_for(spec.name, sub.name, slot, layer_inputs) + if v is not None: + args[slot] = v + continue + + # Baton-fed slot: host won't write it (intermediate_indices); use + # a correctly-sized zero placeholder so the array shape is valid. + args[slot] = np.zeros( + _slot_shape_for(spec.name, sub.name, slot, config), dtype=bfloat16 + ) + + intermediate_idx = _intermediate_slots_for_sub(spec, idx) + static_idx = ( + {sub.weight_slot_in_standalone} + if sub.weight_slot_in_standalone is not None + else set() + ) + + kernel_name = f"{spec.name}__{sub.name}" + bo_key = f"C_{spec.name}_{sub.name}_L{layer_idx}" + + result = cache.load_and_run( + kernel_name, + backend, + *args, + output_indices=[sub.output_slot_in_standalone], + static_input_indices=static_idx, + intermediate_indices=intermediate_idx, + bo_key=bo_key, + ) + results[sub.name] = result[sub.output_slot_in_standalone] + + elapsed = time.perf_counter() - t0 + results["_wall_s"] = elapsed + return results diff --git a/programming_examples/llama32_1b/ablation/prefill/cells/cell_d_merged.py b/programming_examples/llama32_1b/ablation/prefill/cells/cell_d_merged.py new file mode 100644 index 000000000..318cdd958 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/prefill/cells/cell_d_merged.py @@ -0,0 +1,151 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + +"""Cell D — production: invoke the merged ELFs (rms_gemms_rope.elf with 6 +launches; o_ffn.elf with 8 launches) using the production KernelCache + +backend presets. +""" + +import os +import sys + +# Ensure llama32_1b/ is on sys.path so kernel_builder and multi_launch_builder +# are importable whether this file is run directly or imported from the +# prefill/ package root. +_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +_LLAMA_DIR = os.path.normpath(os.path.join(_THIS_DIR, "..", "..", "..")) +if _LLAMA_DIR not in sys.path: + sys.path.insert(0, _LLAMA_DIR) + +import time + +import numpy as np +from ml_dtypes import bfloat16 + +from kernel_builder.cache import KernelCache +from kernel_builder.backend_presets import RMS_GEMMS_ROPE_BACKEND, O_FFN_BACKEND +from multi_launch_builder.rms_gemms_rope_multi import build_rms_gemms_rope_module +from multi_launch_builder.o_ffn_multi import build_o_ffn_module + +CONFIG = { + "seq_len": 2048, + "emb_dim": 2048, + "kv_dim": 512, + "n_heads": 32, + "n_kv_heads": 8, + "head_dim": 64, + "hidden_dim": 8192, +} + + +def compile_cell_d_rms_gemms_rope(cache: KernelCache): + if "rms_gemms_rope" in cache.artifacts: + return + mod = build_rms_gemms_rope_module( + seq_len=CONFIG["seq_len"], + emb_dim=CONFIG["emb_dim"], + kv_dim=CONFIG["kv_dim"], + n_heads=CONFIG["n_heads"], + n_kv_heads=CONFIG["n_kv_heads"], + head_dim=CONFIG["head_dim"], + ) + cache.compile_and_cache( + "rms_gemms_rope", mod, {"verbose": cache.verbose, **RMS_GEMMS_ROPE_BACKEND} + ) + cache._save_manifest() + + +def compile_cell_d_o_ffn(cache: KernelCache): + if "o_ffn" in cache.artifacts: + return + mod = build_o_ffn_module( + seq_len=CONFIG["seq_len"], + emb_dim=CONFIG["emb_dim"], + hidden_dim=CONFIG["hidden_dim"], + ) + cache.compile_and_cache("o_ffn", mod, {"verbose": cache.verbose, **O_FFN_BACKEND}) + cache._save_manifest() + + +def run_cell_d_rms_gemms_rope(cache, layer_inputs, layer_idx=0): + """One rms_gemms_rope call (6 launches in one xrt.run). + layer_inputs has keys: x_in, norm_w, wq, wk, wv, lut_q, lut_k. + Returns dict with normed, q, k, v, q_roped, k_roped, _wall_s. + """ + seq = CONFIG["seq_len"] + emb = CONFIG["emb_dim"] + kv = CONFIG["kv_dim"] + args = [ + layer_inputs["x_in"], + layer_inputs["norm_w"], + np.zeros((seq, emb), dtype=bfloat16), # normed + layer_inputs["wq"], + np.zeros((seq, emb), dtype=bfloat16), # q + layer_inputs["wk"], + np.zeros((seq, kv), dtype=bfloat16), # k + layer_inputs["wv"], + np.zeros((seq, kv), dtype=bfloat16), # v + layer_inputs["lut_q"], + layer_inputs["lut_k"], + np.zeros((seq, emb), dtype=bfloat16), # q_roped + np.zeros((seq, kv), dtype=bfloat16), # k_roped + ] + t0 = time.perf_counter() + out = cache.load_and_run( + "rms_gemms_rope", + RMS_GEMMS_ROPE_BACKEND, + *args, + output_indices=[2, 4, 6, 8, 11, 12], + static_input_indices={1, 3, 5, 7, 9, 10}, + intermediate_indices={2, 4, 6, 8, 11, 12}, + bo_key=f"D_rms_gemms_rope_L{layer_idx}", + ) + elapsed = time.perf_counter() - t0 + return { + "normed": out[2], + "q": out[4], + "k": out[6], + "v": out[8], + "q_roped": out[11], + "k_roped": out[12], + "_wall_s": elapsed, + } + + +def run_cell_d_o_ffn(cache, layer_inputs, layer_idx=0): + """One o_ffn call (8 launches in one xrt.run). + layer_inputs has: attn_out, wo, x_residual, ffn_norm_w, w_gate, w_up, w_down. + Returns dict with output, _wall_s. + """ + seq = CONFIG["seq_len"] + emb = CONFIG["emb_dim"] + hid = CONFIG["hidden_dim"] + n_total = seq * emb + args = [ + layer_inputs["attn_out"], # 0 + layer_inputs["wo"], # 1 + np.zeros((seq, emb), dtype=bfloat16), # 2 proj + layer_inputs["x_residual"], # 3 + np.zeros((seq, emb), dtype=bfloat16), # 4 res1 + layer_inputs["ffn_norm_w"], # 5 + np.zeros((seq, emb), dtype=bfloat16), # 6 normed2 + layer_inputs["w_gate"], # 7 + np.zeros((seq, hid), dtype=bfloat16), # 8 gate + layer_inputs["w_up"], # 9 + np.zeros((seq, hid), dtype=bfloat16), # 10 up + np.zeros((seq, hid), dtype=bfloat16), # 11 swiglu + layer_inputs["w_down"], # 12 + np.zeros((seq, emb), dtype=bfloat16), # 13 down + np.zeros(n_total, dtype=bfloat16), # 14 output (1D) + ] + t0 = time.perf_counter() + out = cache.load_and_run( + "o_ffn", + O_FFN_BACKEND, + *args, + output_indices=[14], + static_input_indices={1, 5, 7, 9, 12}, + intermediate_indices={2, 4, 6, 8, 10, 11, 13, 14}, + bo_key=f"D_o_ffn_L{layer_idx}", + ) + return {"output": out[14], "_wall_s": time.perf_counter() - t0} diff --git a/programming_examples/llama32_1b/ablation/prefill/cells/common.py b/programming_examples/llama32_1b/ablation/prefill/cells/common.py new file mode 100644 index 000000000..82992bfb1 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/prefill/cells/common.py @@ -0,0 +1,84 @@ +"""Shared helpers for prefill ablation cells. + +Lifted (and extended for two-backend support) from Plan 1's +ablation/cells/common.py. The original Plan 1 file is read-only. + +- compile_standalone_kernels(cache, group_name, registry, backend_preset): + Compile every standalone in `registry` into `cache`, using the actual + public func name extracted from the MLIR module as instance_name. +- _extract_public_func_name(mlir_text): regex over the module string. +- _share_bo(cache, src_key, src_slot, dst_key, dst_slot): alias cached BOs + for Cell C's baton-pass. +- standalone_backend_kwargs(backend_preset, verbose): returns backend kwargs + with instance_name removed (set per-kernel by compile_standalone_kernels). +""" + +import re + +from air.ir import Context as MLIRContext + +from kernel_builder.cache import KernelCache + + +def _extract_public_func_name(mlir_text): + """Find the first non-private `func.func @` in the module text.""" + for line in mlir_text.split("\n"): + if "func.func @" in line and "private" not in line: + m = re.search(r"@(\w+)", line) + if m: + return m.group(1) + raise ValueError("no public func.func found in module") + + +def standalone_backend_kwargs(backend_preset, verbose=False): + """Backend kwargs with instance_name removed (set per-kernel by caller).""" + base = {**backend_preset, "verbose": verbose} + base.pop("instance_name", None) + return base + + +def compile_standalone_kernels( + cache: KernelCache, group_name: str, registry, backend_preset +): + """Compile every standalone in `registry` into `cache` under names + f"{group_name}__{name}". Skip any kernel already in cache.artifacts. + + Each registry entry: (name, build_fn, build_kwargs). + """ + for name, build_fn, kwargs in registry: + kernel_name = f"{group_name}__{name}" + if kernel_name in cache.artifacts: + continue + with MLIRContext(): + mlir_module = build_fn(**kwargs) + public_func = _extract_public_func_name(str(mlir_module)) + be = standalone_backend_kwargs(backend_preset, verbose=cache.verbose) + be["instance_name"] = public_func + cache.compile_and_cache(kernel_name, mlir_module, be) + cache._save_manifest() + + +def _share_bo(cache, src_key, src_slot, dst_key, dst_slot): + """Replace cached BO at (dst_key, dst_slot) with the same xrt.bo as + (src_key, src_slot). Only valid after both kernels' first call has + materialized BOs.""" + src_bos = cache._cached_bos[src_key] + dst_bos = cache._cached_bos[dst_key] + dst_bos[dst_slot] = src_bos[src_slot] + + +def main(): + """python3 -m cells.common — compile both kernel-groups' standalones.""" + from kernel_builder.backend_presets import RMS_GEMMS_ROPE_BACKEND, O_FFN_BACKEND + from standalone_builders.rms_gemms_rope import STANDALONES as RMS_STD + from standalone_builders.o_ffn import STANDALONES as O_STD + + cache = KernelCache(cache_dir="standalone_cache", verbose=True) + cache.load_manifest() + compile_standalone_kernels(cache, "rms_gemms_rope", RMS_STD, RMS_GEMMS_ROPE_BACKEND) + compile_standalone_kernels(cache, "o_ffn", O_STD, O_FFN_BACKEND) + print(f"Compiled {len(cache.artifacts)} standalone ELFs.") + + +if __name__ == "__main__": + main() diff --git a/programming_examples/llama32_1b/ablation/prefill/cells/flash_attn_const.py b/programming_examples/llama32_1b/ablation/prefill/cells/flash_attn_const.py new file mode 100644 index 000000000..4f1b0f411 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/prefill/cells/flash_attn_const.py @@ -0,0 +1,74 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + +"""FlashAttention invariant: same standalone ELF + same invocation in every cell. + +FA's MLIR builder is at programming_examples/flash_attention/kernel_fusion_based/attn_npu2_seqfirst.py +with kwargs matching Plan 1's compile_all_kernels() in llama32_1b_prefill.py. +""" + +import time + +import numpy as np +from ml_dtypes import bfloat16 + +from kernel_builder.cache import KernelCache + + +def _attn_backend_kwargs(): + return { + "verbose": False, + "omit_while_true_loop": False, # head_dim=64, lkp=64 enables shared buffers + "omit_pingpong": "all", + "runtime_loop_tiling_sizes": [1, 1], + "output_format": "elf", + "instance_name": "attention_bf16", + } + + +def compile_flash_attn(cache: KernelCache, config): + """Compile FA ELF if not already cached. ~46s first time per profile.md.""" + if "flash_attn" in cache.artifacts: + return + from flash_attention.kernel_fusion_based.attn_npu2_seqfirst import ( + build_module as build_attn, + ) + + seq = config["seq_len"] + head_dim = config["head_dim"] + n_heads = config["n_heads"] + n_kv_heads = config["n_kv_heads"] + mod = build_attn( + lk=seq, + lkp=head_dim, + lq=seq, + lqp=256, + dk=head_dim, + dv=head_dim, + num_q_tiles=4, + num_cascade_stages=4, + num_heads=n_heads, + num_kv_heads=n_kv_heads, + causal=True, + ) + cache.compile_and_cache("flash_attn", mod, _attn_backend_kwargs()) + cache._save_manifest() + + +def run_flash_attn(cache, q_roped, k_roped, v, layer_idx=0): + """Run FA on extracted q_roped/k_roped/v from rms_gemms_rope. + Returns attn_out (extracted to host) ready to feed o_ffn. + """ + seq = q_roped.shape[0] + emb = q_roped.shape[1] + args = [q_roped, k_roped, v, np.zeros((seq, emb), dtype=bfloat16)] + t0 = time.perf_counter() + out = cache.load_and_run( + "flash_attn", + _attn_backend_kwargs(), + *args, + output_indices=[3], + intermediate_indices={3}, + bo_key=f"FA_L{layer_idx}", + ) + return {"attn_out": out[3], "_wall_s": time.perf_counter() - t0} diff --git a/programming_examples/llama32_1b/ablation/prefill/cells/multi_layer.py b/programming_examples/llama32_1b/ablation/prefill/cells/multi_layer.py new file mode 100644 index 000000000..68585cb42 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/prefill/cells/multi_layer.py @@ -0,0 +1,86 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + +"""16-layer prefill wrapper. + +Threads: rms_gemms_rope[L] -> FA[L] -> o_ffn[L] -> rms_gemms_rope[L+1] + +The cell-A/B/C/D dispatch strategy is independent of this wrapper; we +take the cell's per-kernel-group runner as a parameter. +""" + +import time + +import numpy as np +from ml_dtypes import bfloat16 + +from cells.flash_attn_const import run_flash_attn + + +def run_16_layer_prefill( + cache, + config, + run_rms_gemms_rope, + run_o_ffn, + layer_inputs_per_layer, +): + """Run a 16-layer prefill via the supplied per-kernel-group runners. + + Args: + cache: shared KernelCache (FA + both groups + standalones all reside here) + config: dict from cell_d_merged.CONFIG + run_rms_gemms_rope(cache, layer_inputs, layer_idx) -> {normed,q,k,v,q_roped,k_roped, _wall_s} + run_o_ffn(cache, layer_inputs, layer_idx) -> {output, _wall_s} + layer_inputs_per_layer: list of N dicts, each with all per-layer weights+LUTs+x_in[layer 0 only] + + Returns dict with: + per_layer_wall: list of N floats (wall time per layer including FA) + total_wall: float + final_output: numpy array (last layer's o_ffn output, reshaped to (seq, emb)) + """ + n_layers = len(layer_inputs_per_layer) + per_layer_wall = [] + x_in = layer_inputs_per_layer[0]["x_in"] + final_output = None + + t_total_start = time.perf_counter() + for L in range(n_layers): + layer_in = dict(layer_inputs_per_layer[L]) + layer_in["x_in"] = x_in # threaded from previous layer + + t_layer_start = time.perf_counter() + + # 1. rms_gemms_rope + rg_out = run_rms_gemms_rope(cache, layer_in, layer_idx=L) + # 2. FA (invariant) + # rms_gemms_rope returns 1D flat arrays; FA expects 2D (seq, dim) + seq = config["seq_len"] + emb = config["emb_dim"] + kv = config["kv_dim"] + q_roped_2d = rg_out["q_roped"].reshape(seq, emb) + k_roped_2d = rg_out["k_roped"].reshape(seq, kv) + v_2d = rg_out["v"].reshape(seq, kv) + fa_out = run_flash_attn(cache, q_roped_2d, k_roped_2d, v_2d, layer_idx=L) + # 3. o_ffn — assemble inputs + of_in = { + "attn_out": fa_out["attn_out"], + "wo": layer_in["wo"], + "x_residual": x_in, + "ffn_norm_w": layer_in["ffn_norm_w"], + "w_gate": layer_in["w_gate"], + "w_up": layer_in["w_up"], + "w_down": layer_in["w_down"], + } + of_out = run_o_ffn(cache, of_in, layer_idx=L) + # The o_ffn output (slot 14) is 1D (n_total = seq*emb); reshape for next layer + x_in = of_out["output"].reshape(config["seq_len"], config["emb_dim"]) + final_output = x_in + + per_layer_wall.append(time.perf_counter() - t_layer_start) + + total_wall = time.perf_counter() - t_total_start + return { + "per_layer_wall": per_layer_wall, + "total_wall": total_wall, + "final_output": final_output, + } diff --git a/programming_examples/llama32_1b/ablation/prefill/golden/__init__.py b/programming_examples/llama32_1b/ablation/prefill/golden/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/programming_examples/llama32_1b/ablation/prefill/golden/golden_meta.json b/programming_examples/llama32_1b/ablation/prefill/golden/golden_meta.json new file mode 100644 index 000000000..f21aadddd --- /dev/null +++ b/programming_examples/llama32_1b/ablation/prefill/golden/golden_meta.json @@ -0,0 +1,44 @@ +{ + "config": { + "seq_len": 2048, + "emb_dim": 2048, + "kv_dim": 512, + "n_heads": 32, + "n_kv_heads": 8, + "head_dim": 64, + "hidden_dim": 8192 + }, + "rms_gemms_rope": { + "input_hashes": { + "x_in": "fcbc90cb84de3921", + "norm_w": "2b68a598666f46b7", + "wq": "644b193c8ad8deb2", + "wk": "d99f752b4ef2e7cb", + "wv": "170cf86e99d6e81c", + "lut_q": "ea89e3700fc1f79c", + "lut_k": "1af9035ca8e4cb69" + }, + "output_hashes": { + "normed": "97c83313d0086b24", + "q": "841e787880869d03", + "k": "970a6cbd94eed6fd", + "v": "a9a28b1b08840976", + "q_roped": "0bc1552da337d5e2", + "k_roped": "b53a3553b0c34dbb" + } + }, + "o_ffn": { + "input_hashes": { + "attn_out": "c142255ffc76363f", + "wo": "f79d9f01ecb1f849", + "x_residual": "fcbc90cb84de3921", + "ffn_norm_w": "662073a56ab4cafe", + "w_gate": "ae0272f05a315b90", + "w_up": "f16ac32ad33c9d4a", + "w_down": "3017d3b502e1c327" + }, + "output_hashes": { + "output": "c87c94798ef2a94b" + } + } +} \ No newline at end of file diff --git a/programming_examples/llama32_1b/ablation/prefill/golden/golden_o_ffn_prefill.npz b/programming_examples/llama32_1b/ablation/prefill/golden/golden_o_ffn_prefill.npz new file mode 100644 index 000000000..ae6d75f8f Binary files /dev/null and b/programming_examples/llama32_1b/ablation/prefill/golden/golden_o_ffn_prefill.npz differ diff --git a/programming_examples/llama32_1b/ablation/prefill/golden/golden_rms_gemms_rope_prefill.npz b/programming_examples/llama32_1b/ablation/prefill/golden/golden_rms_gemms_rope_prefill.npz new file mode 100644 index 000000000..3143ae50a Binary files /dev/null and b/programming_examples/llama32_1b/ablation/prefill/golden/golden_rms_gemms_rope_prefill.npz differ diff --git a/programming_examples/llama32_1b/ablation/prefill/golden/regen_golden.py b/programming_examples/llama32_1b/ablation/prefill/golden/regen_golden.py new file mode 100644 index 000000000..07127fffe --- /dev/null +++ b/programming_examples/llama32_1b/ablation/prefill/golden/regen_golden.py @@ -0,0 +1,130 @@ +"""Regenerate prefill golden fixtures by running Cell D once for each kernel-group. + +Uses deterministic synthetic inputs (numpy seed=42 for layer 0). +Outputs: + golden/golden_rms_gemms_rope_prefill.npz + golden/golden_o_ffn_prefill.npz + golden/golden_meta.json +""" + +import hashlib +import json +import os +import sys + +import numpy as np +from ml_dtypes import bfloat16 + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from kernel_builder.cache import KernelCache +from cells.cell_d_merged import ( + CONFIG, + compile_cell_d_rms_gemms_rope, + compile_cell_d_o_ffn, + run_cell_d_rms_gemms_rope, + run_cell_d_o_ffn, +) + + +def _synthetic_layer_inputs(layer_idx, config): + """Deterministic synthetic inputs for one prefill layer (seq=2048). + + Same seeding scheme as Plan 1: seed = 42 + layer_idx. + """ + rng = np.random.default_rng(42 + layer_idx) + seq = config["seq_len"] + emb = config["emb_dim"] + kv = config["kv_dim"] + hid = config["hidden_dim"] + return { + "x_in": rng.standard_normal((seq, emb)).astype(bfloat16), + "norm_w": rng.standard_normal(emb).astype(bfloat16), + "wq": rng.standard_normal((emb, emb)).astype(bfloat16), + "wk": rng.standard_normal((emb, kv)).astype(bfloat16), + "wv": rng.standard_normal((emb, kv)).astype(bfloat16), + "lut_q": rng.standard_normal(seq * emb).astype(bfloat16), + "lut_k": rng.standard_normal(seq * kv).astype(bfloat16), + "wo": rng.standard_normal((emb, emb)).astype(bfloat16), + "ffn_norm_w": rng.standard_normal(emb).astype(bfloat16), + "w_gate": rng.standard_normal((emb, hid)).astype(bfloat16), + "w_up": rng.standard_normal((emb, hid)).astype(bfloat16), + "w_down": rng.standard_normal((hid, emb)).astype(bfloat16), + } + + +def main(): + cache = KernelCache(cache_dir="standalone_cache", verbose=True) + cache.load_manifest() + compile_cell_d_rms_gemms_rope(cache) + compile_cell_d_o_ffn(cache) + + inputs = _synthetic_layer_inputs(0, CONFIG) + + # rms_gemms_rope golden + rg_inputs = { + k: inputs[k] for k in ["x_in", "norm_w", "wq", "wk", "wv", "lut_q", "lut_k"] + } + rg_out = run_cell_d_rms_gemms_rope(cache, rg_inputs, layer_idx=0) + rg_path = os.path.join( + os.path.dirname(__file__), "golden_rms_gemms_rope_prefill.npz" + ) + np.savez(rg_path, **{k: v for k, v in rg_out.items() if not k.startswith("_")}) + + # For o_ffn golden, attn_out comes from FA in production. For the golden + # we use a CPU FA reference computed from rg_out's q_roped/k_roped/v — + # since FA is invariant across cells, all cells will see the same attn_out. + # Simplest: synthesize attn_out from the same RNG (it is what flows into + # o_ffn's slot 0 in every cell; the bytes are determined upstream). + attn_out = ( + np.random.default_rng(42 + 0 + 1000) + .standard_normal((CONFIG["seq_len"], CONFIG["emb_dim"])) + .astype(bfloat16) + ) + of_inputs = { + "attn_out": attn_out, + "wo": inputs["wo"], + "x_residual": inputs["x_in"], # the residual is the layer input + "ffn_norm_w": inputs["ffn_norm_w"], + "w_gate": inputs["w_gate"], + "w_up": inputs["w_up"], + "w_down": inputs["w_down"], + } + of_out = run_cell_d_o_ffn(cache, of_inputs, layer_idx=0) + of_path = os.path.join(os.path.dirname(__file__), "golden_o_ffn_prefill.npz") + np.savez(of_path, **{k: v for k, v in of_out.items() if not k.startswith("_")}) + + meta = { + "config": CONFIG, + "rms_gemms_rope": { + "input_hashes": { + k: hashlib.sha256(v.tobytes()).hexdigest()[:16] + for k, v in rg_inputs.items() + }, + "output_hashes": { + k: hashlib.sha256(v.tobytes()).hexdigest()[:16] + for k, v in rg_out.items() + if not k.startswith("_") + }, + }, + "o_ffn": { + "input_hashes": { + k: hashlib.sha256(v.tobytes()).hexdigest()[:16] + for k, v in of_inputs.items() + }, + "output_hashes": { + k: hashlib.sha256(v.tobytes()).hexdigest()[:16] + for k, v in of_out.items() + if not k.startswith("_") + }, + }, + } + with open(os.path.join(os.path.dirname(__file__), "golden_meta.json"), "w") as f: + json.dump(meta, f, indent=2) + print(f"Wrote {rg_path}, {of_path}, golden_meta.json") + + +if __name__ == "__main__": + main() diff --git a/programming_examples/llama32_1b/ablation/prefill/run_ablation.py b/programming_examples/llama32_1b/ablation/prefill/run_ablation.py new file mode 100644 index 000000000..1eb006e48 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/prefill/run_ablation.py @@ -0,0 +1,480 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + +"""Run the prefill 4-cell ablation. + +Modes: + --scope=single-layer 5 trials × 1-layer cell call (per kernel-group) + --scope=16-layer 5 trials × 16-layer triple (rms->FA->o_ffn) loop + --scope=both (default) both above + +Run from programming_examples/llama32_1b/ablation/prefill/build/ +(where standalone_cache/ lives and xclbins are found). +""" + +import argparse +import json +import os +import sys +import time + +# Path setup: this script lives in prefill/; CWD is build/ (where standalone_cache/ lives) +# prefill/ -> ablation/ -> llama32_1b/ -> programming_examples/ +_PREFILL = os.path.dirname(os.path.abspath(__file__)) +_ABLATION = os.path.dirname(_PREFILL) +_LLAMA = os.path.dirname(_ABLATION) +_PROG_EXAMPLES = os.path.dirname(_LLAMA) + +# Insert in ascending priority: _PROG_EXAMPLES appended, _PREFILL at front. +# Use append for lower-priority dirs so they don't shadow prefill's 'cells' package. +for p in (_PROG_EXAMPLES, _LLAMA, _ABLATION): + if p not in sys.path: + sys.path.append(p) +# _PREFILL must be at index 0 so prefill/cells/ wins over ablation/cells/. +if _PREFILL in sys.path: + sys.path.remove(_PREFILL) +sys.path.insert(0, _PREFILL) + +import numpy as np +from ml_dtypes import bfloat16 + +from kernel_builder.cache import KernelCache +from kernel_builder.backend_presets import RMS_GEMMS_ROPE_BACKEND, O_FFN_BACKEND + +from validate import validate_against_golden, GoldenMismatch +from cells import cell_a_naive, cell_b_static, cell_c_charitable, cell_d_merged +from cells.flash_attn_const import compile_flash_attn +from cells.multi_layer import run_16_layer_prefill +from specs.rms_gemms_rope import SPEC as RG_SPEC +from specs.o_ffn import SPEC as OF_SPEC +from golden.regen_golden import _synthetic_layer_inputs + +GOLDEN_DIR = os.path.join(_PREFILL, "golden") + + +# --------------------------------------------------------------------------- +# Output key adapters: convert cell A/B/C sub-launch dicts to golden-comparable +# --------------------------------------------------------------------------- + + +def _rg_cell_outputs(out, cell): + """Map run_cell_* output dict to golden keys for rms_gemms_rope.""" + if cell == "D": + # Cell D already returns {normed, q, k, v, q_roped, k_roped, _wall_s} + return {k: v for k, v in out.items() if not k.startswith("_")} + # Cell A/B/C: sub-launch names as keys + return { + "normed": out["rmsnorm"], + "q": out["q_gemm"], + "k": out["k_gemm"], + "v": out["v_gemm"], + "q_roped": out["rope_q"], + "k_roped": out["rope_k"], + } + + +def _of_cell_outputs(out, cell): + """Map run_cell_* output dict to golden keys for o_ffn.""" + if cell == "D": + # Cell D returns {output, _wall_s} + return {"output": out["output"]} + # Cell A/B/C: last sub-launch is "ffn_add"; golden only checks "output" + return {"output": out["ffn_add"].reshape(-1)} + + +# --------------------------------------------------------------------------- +# Cell runners (single-layer) — unified interface +# --------------------------------------------------------------------------- + + +def _run_rg(cell, cache, layer_inputs): + """Run rms_gemms_rope for the given cell. Returns raw output dict.""" + if cell == "A": + return cell_a_naive.run_cell_a( + cache, RG_SPEC, layer_inputs, cell_d_merged.CONFIG, RMS_GEMMS_ROPE_BACKEND + ) + if cell == "B": + return cell_b_static.run_cell_b( + cache, RG_SPEC, layer_inputs, cell_d_merged.CONFIG, RMS_GEMMS_ROPE_BACKEND + ) + if cell == "C": + return cell_c_charitable.run_cell_c( + cache, RG_SPEC, layer_inputs, cell_d_merged.CONFIG, RMS_GEMMS_ROPE_BACKEND + ) + if cell == "D": + rg_in = { + k: layer_inputs[k] + for k in ["x_in", "norm_w", "wq", "wk", "wv", "lut_q", "lut_k"] + } + return cell_d_merged.run_cell_d_rms_gemms_rope(cache, rg_in) + raise ValueError(f"unknown cell {cell!r}") + + +def _run_of(cell, cache, layer_inputs): + """Run o_ffn for the given cell. Returns raw output dict. + + layer_inputs must contain: attn_out, wo, x_residual, ffn_norm_w, + w_gate, w_up, w_down (plus any extra keys ignored by A/B/C). + """ + if cell == "A": + return cell_a_naive.run_cell_a( + cache, OF_SPEC, layer_inputs, cell_d_merged.CONFIG, O_FFN_BACKEND + ) + if cell == "B": + return cell_b_static.run_cell_b( + cache, OF_SPEC, layer_inputs, cell_d_merged.CONFIG, O_FFN_BACKEND + ) + if cell == "C": + return cell_c_charitable.run_cell_c( + cache, OF_SPEC, layer_inputs, cell_d_merged.CONFIG, O_FFN_BACKEND + ) + if cell == "D": + of_in = { + k: layer_inputs[k] + for k in [ + "attn_out", + "wo", + "x_residual", + "ffn_norm_w", + "w_gate", + "w_up", + "w_down", + ] + } + return cell_d_merged.run_cell_d_o_ffn(cache, of_in) + raise ValueError(f"unknown cell {cell!r}") + + +# --------------------------------------------------------------------------- +# 16-layer adapter: convert cell A/B/C output to multi_layer-expected shape +# --------------------------------------------------------------------------- + + +def _make_rg_runner_16layer(cell, cache): + """Return a run_rms_gemms_rope(cache, layer_in, layer_idx) adapter for multi_layer. + + multi_layer.py expects the function to return a dict with keys: + q_roped, k_roped, v (and others, unused by multi_layer) + all as 1D flat arrays (it reshapes them internally before calling FA). + """ + + def run(c, layer_in, layer_idx=0): + if cell in ("A", "B", "C"): + out = _run_rg(cell, c, layer_in) + # Convert sub-launch names to canonical names for multi_layer + out["q_roped"] = out["rope_q"] + out["k_roped"] = out["rope_k"] + out["q"] = out["q_gemm"] + out["k"] = out["k_gemm"] + out["v"] = out["v_gemm"] + out["normed"] = out["rmsnorm"] + else: + out = _run_rg(cell, c, layer_in) + return out + + return run + + +def _make_of_runner_16layer(cell, cache): + """Return a run_o_ffn(cache, of_in, layer_idx) adapter for multi_layer. + + multi_layer.py assembles of_in with all needed keys (attn_out, wo, + x_residual, ffn_norm_w, w_gate, w_up, w_down) and calls this. + We need to return a dict with key 'output' as a 1D array that multi_layer + reshapes for the next layer's x_in. + """ + + def run(c, of_in, layer_idx=0): + out = _run_of(cell, c, of_in) + if cell in ("A", "B", "C"): + # Rename ffn_add -> output for multi_layer compatibility + out["output"] = out["ffn_add"].reshape(-1) + return out + + return run + + +# --------------------------------------------------------------------------- +# Context management +# --------------------------------------------------------------------------- + + +def _unload_all_contexts(cache): + """Unload all XRT HW contexts and drop all cached BOs. + + The NPU has a limited number of HW context slots (~16). When switching + between single-layer (14+ standalone contexts) and 16-layer (up to 15 + contexts for Cell A/B/C), we must release all contexts first to avoid + hitting the limit. + + BOs are allocated against a specific XRT device handle; after unloading + the backend that handle is nulled, so the old BO objects are unusable. + We must also clear _cached_bos so the next load_and_run allocates fresh + BOs against the new device. This means preloaded Cell B/C weights are + lost and will be re-written on the next call (acceptable since the + 16-layer loop only runs one cell at a time anyway). + """ + for name, (backend, _) in list(cache._loaded.items()): + try: + backend.unload() + except Exception: + pass + cache._loaded.clear() + cache._cached_bos.clear() + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--trials", type=int, default=5) + ap.add_argument( + "--scope", + choices=["single-layer", "16-layer", "both"], + default="both", + ) + ap.add_argument("--out", default=None) + args = ap.parse_args() + + cache = KernelCache(cache_dir="standalone_cache", verbose=False) + cache.load_manifest() + + # ---- Compile all cells + FA (idempotent — skips if already cached) ---- + print("=== Compiling kernels (idempotent) ===") + cell_a_naive.compile_cell_a(cache, RG_SPEC, RMS_GEMMS_ROPE_BACKEND) + cell_a_naive.compile_cell_a(cache, OF_SPEC, O_FFN_BACKEND) + cell_b_static.compile_cell_b(cache, RG_SPEC, RMS_GEMMS_ROPE_BACKEND) + cell_b_static.compile_cell_b(cache, OF_SPEC, O_FFN_BACKEND) + cell_c_charitable.compile_cell_c(cache, RG_SPEC, RMS_GEMMS_ROPE_BACKEND) + cell_c_charitable.compile_cell_c(cache, OF_SPEC, O_FFN_BACKEND) + cell_d_merged.compile_cell_d_rms_gemms_rope(cache) + cell_d_merged.compile_cell_d_o_ffn(cache) + compile_flash_attn(cache, cell_d_merged.CONFIG) + print("All kernels compiled/cached.\n") + + # ---- Generate per-layer synthetic inputs (all 16 layers) ---- + layer_inputs_per_layer = [ + _synthetic_layer_inputs(L, cell_d_merged.CONFIG) for L in range(16) + ] + + # ---- Pre-load weights for Cell B and Cell C (both kernel-groups, all 16 layers) ---- + print("=== Pre-loading weights for Cell B and Cell C ===") + rg_weights = [ + {k: li[k] for k in ["norm_w", "wq", "wk", "wv", "lut_q", "lut_k"]} + for li in layer_inputs_per_layer + ] + of_weights = [ + {k: li[k] for k in ["wo", "ffn_norm_w", "w_gate", "w_up", "w_down"]} + for li in layer_inputs_per_layer + ] + + cell_b_static.preload_cell_b( + cache, RG_SPEC, rg_weights, cell_d_merged.CONFIG, RMS_GEMMS_ROPE_BACKEND + ) + cell_b_static.preload_cell_b( + cache, OF_SPEC, of_weights, cell_d_merged.CONFIG, O_FFN_BACKEND + ) + cell_c_charitable.preload_cell_c( + cache, RG_SPEC, rg_weights, cell_d_merged.CONFIG, RMS_GEMMS_ROPE_BACKEND + ) + cell_c_charitable.preload_cell_c( + cache, OF_SPEC, of_weights, cell_d_merged.CONFIG, O_FFN_BACKEND + ) + print("Preload done.\n") + + results = { + "config": cell_d_merged.CONFIG, + "trials": args.trials, + "scope": args.scope, + "cells": {}, + } + + # ---- Build layer-0 inputs for single-layer validation and timing ---- + layer0 = layer_inputs_per_layer[0] + # o_ffn needs attn_out (from FA in production; synthesized here to match regen_golden) + attn_out_layer0 = ( + np.random.default_rng(42 + 0 + 1000) + .standard_normal( + (cell_d_merged.CONFIG["seq_len"], cell_d_merged.CONFIG["emb_dim"]) + ) + .astype(bfloat16) + ) + of_layer0 = dict(layer0) + of_layer0["attn_out"] = attn_out_layer0 + of_layer0["x_residual"] = layer0["x_in"] + + # ---- Validation: single-layer Cell A/B/C/D vs both goldens ---- + print("=== Validation (layer 0, single-layer) ===") + for cell in ("A", "B", "C", "D"): + cell_results = {} + + # rms_gemms_rope validation + try: + rg_out = _run_rg(cell, cache, layer0) + rg_cell_out = _rg_cell_outputs(rg_out, cell) + validate_against_golden( + rg_cell_out, GOLDEN_DIR, "golden_rms_gemms_rope_prefill.npz" + ) + cell_results["rms_gemms_rope"] = {"validation": "PASS"} + print(f" Cell {cell} rms_gemms_rope: PASS") + except GoldenMismatch as e: + cell_results["rms_gemms_rope"] = {"validation": "FAIL", "error": str(e)} + print(f" Cell {cell} rms_gemms_rope: FAIL - {e}") + except Exception as e: + cell_results["rms_gemms_rope"] = {"validation": "ERROR", "error": str(e)} + print(f" Cell {cell} rms_gemms_rope: ERROR - {e}") + + # o_ffn validation + try: + of_out = _run_of(cell, cache, of_layer0) + of_cell_out = _of_cell_outputs(of_out, cell) + validate_against_golden(of_cell_out, GOLDEN_DIR, "golden_o_ffn_prefill.npz") + cell_results["o_ffn"] = {"validation": "PASS"} + print(f" Cell {cell} o_ffn: PASS") + except GoldenMismatch as e: + cell_results["o_ffn"] = {"validation": "FAIL", "error": str(e)} + print(f" Cell {cell} o_ffn: FAIL - {e}") + except Exception as e: + cell_results["o_ffn"] = {"validation": "ERROR", "error": str(e)} + print(f" Cell {cell} o_ffn: ERROR - {e}") + + results["cells"][cell] = cell_results + + print() + + # ---- Timing: single-layer scope ---- + if args.scope in ("single-layer", "both"): + print("=== Timing: single-layer scope ===") + for cell in ("A", "B", "C", "D"): + cr = results["cells"][cell] + + # rms_gemms_rope timing + if cr.get("rms_gemms_rope", {}).get("validation") == "PASS": + times_rg = [] + for _ in range(args.trials): + o = _run_rg(cell, cache, layer0) + times_rg.append(o["_wall_s"]) + keep = sorted(times_rg[1:]) + med_rg = keep[len(keep) // 2] + cr["rms_gemms_rope"]["single_layer"] = { + "all_trials_s": times_rg, + "median_s": med_rg, + "min_s": min(keep), + "max_s": max(keep), + } + print( + f" Cell {cell} rg single-layer: " + f"med={med_rg * 1000:.2f}ms " + f"[{min(keep)*1000:.2f}-{max(keep)*1000:.2f}ms] " + f"(warmup={times_rg[0]*1000:.2f}ms)" + ) + + # o_ffn timing + if cr.get("o_ffn", {}).get("validation") == "PASS": + times_of = [] + for _ in range(args.trials): + o = _run_of(cell, cache, of_layer0) + times_of.append(o["_wall_s"]) + keep = sorted(times_of[1:]) + med_of = keep[len(keep) // 2] + cr["o_ffn"]["single_layer"] = { + "all_trials_s": times_of, + "median_s": med_of, + "min_s": min(keep), + "max_s": max(keep), + } + print( + f" Cell {cell} of single-layer: " + f"med={med_of * 1000:.2f}ms " + f"[{min(keep)*1000:.2f}-{max(keep)*1000:.2f}ms] " + f"(warmup={times_of[0]*1000:.2f}ms)" + ) + print() + + # ---- Timing: 16-layer scope ---- + if args.scope in ("16-layer", "both"): + print("=== Timing: 16-layer scope ===") + for cell in ("A", "B", "C", "D"): + cr = results["cells"][cell] + rg_ok = cr.get("rms_gemms_rope", {}).get("validation") == "PASS" + of_ok = cr.get("o_ffn", {}).get("validation") == "PASS" + if not (rg_ok and of_ok): + print( + f" Cell {cell}: skipping 16-layer (validation failed for " + f"{'rms_gemms_rope' if not rg_ok else 'o_ffn'})" + ) + continue + + # Unload all previously opened XRT contexts and BOs before each + # cell's 16-layer run. The NPU has ~16 HW context slots; Cell A/B/C + # each need 14 standalone contexts + FA = 15 total. Starting fresh + # per cell avoids hitting the limit. + # Cell B/C weights are lost with the BOs — re-preload them below. + _unload_all_contexts(cache) + + # Re-preload weights for B and C after the context reset. + if cell == "B": + cell_b_static.preload_cell_b( + cache, + RG_SPEC, + rg_weights, + cell_d_merged.CONFIG, + RMS_GEMMS_ROPE_BACKEND, + ) + cell_b_static.preload_cell_b( + cache, OF_SPEC, of_weights, cell_d_merged.CONFIG, O_FFN_BACKEND + ) + elif cell == "C": + cell_c_charitable.preload_cell_c( + cache, + RG_SPEC, + rg_weights, + cell_d_merged.CONFIG, + RMS_GEMMS_ROPE_BACKEND, + ) + cell_c_charitable.preload_cell_c( + cache, OF_SPEC, of_weights, cell_d_merged.CONFIG, O_FFN_BACKEND + ) + + run_rg_16 = _make_rg_runner_16layer(cell, cache) + run_of_16 = _make_of_runner_16layer(cell, cache) + + times_total = [] + for trial in range(args.trials): + r = run_16_layer_prefill( + cache, + cell_d_merged.CONFIG, + run_rg_16, + run_of_16, + layer_inputs_per_layer, + ) + times_total.append(r["total_wall"]) + + keep = sorted(times_total[1:]) + med = keep[len(keep) // 2] + cr["16_layer"] = { + "all_trials_s": times_total, + "median_s": med, + "min_s": min(keep), + "max_s": max(keep), + } + print( + f" Cell {cell} 16-layer total: " + f"med={med:.3f}s " + f"[{min(keep):.3f}-{max(keep):.3f}s] " + f"(warmup={times_total[0]:.3f}s)" + ) + print() + + # ---- Dump JSON ---- + out_path = args.out or f"results_prefill_{int(time.time())}.json" + with open(out_path, "w") as f: + json.dump(results, f, indent=2) + print(f"Wrote {out_path}") + + +if __name__ == "__main__": + main() diff --git a/programming_examples/llama32_1b/ablation/prefill/specs/__init__.py b/programming_examples/llama32_1b/ablation/prefill/specs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/programming_examples/llama32_1b/ablation/prefill/specs/kernel_group.py b/programming_examples/llama32_1b/ablation/prefill/specs/kernel_group.py new file mode 100644 index 000000000..8ae2f0bf8 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/prefill/specs/kernel_group.py @@ -0,0 +1,72 @@ +"""Frozen dataclasses describing a multi-launch kernel-group's structure. + +A KernelGroupSpec is consumed by parameterized cells (cell_a/b/c/d) so that +the same cell logic works for any kernel-group whose spec is provided. +""" + +from dataclasses import dataclass +from typing import Callable + + +@dataclass(frozen=True) +class SubLaunchSpec: + """One sub-launch's standalone definition. + + Used by Cell A/B/C to invoke the sub-launch as its own xrt.run() call. + Cell D ignores SubLaunchSpec entirely (it uses the merged ELF). + """ + + name: str # "rmsnorm" | "q_gemm" | "rope_q" | ... + builder_ref: Callable # returns a 1-launch mlir.Module at production shape + build_kwargs: dict # passed verbatim to builder_ref + weight_slot_in_standalone: ( + int | None + ) # arg slot of the standalone call holding the weight (or None) + output_slot_in_standalone: int # arg slot of the standalone call holding the output + + +@dataclass(frozen=True) +class BatonLink: + """An intermediate-BO alias to apply in Cell C. + + The producer's output BO becomes the consumer's input BO; the host + skips writing the consumer's input slot via intermediate_indices. + """ + + producer_idx: int # index into KernelGroupSpec.sub_launches + producer_out_slot: int # output slot of producer's standalone signature + consumer_idx: ( + int # index into KernelGroupSpec.sub_launches (must be > producer_idx) + ) + consumer_in_slot: int # input slot of consumer's standalone signature + + +@dataclass(frozen=True) +class KernelGroupSpec: + """Full description of a multi-launch kernel-group for ablation.""" + + name: str # "rms_gemms_rope" | "o_ffn" + sub_launches: tuple # tuple of SubLaunchSpec (frozen) + merged_arg_signature: ( + tuple # tuple of arg-name strings matching production merged ELF args + ) + weight_slots: frozenset # slots in merged signature that are weights/LUTs (Cell D static_input_indices) + intermediate_slots: ( + frozenset # slots in merged signature that are kernel-overwritten intermediates + ) + output_slots_for_validation: tuple # slots whose bytes go in the golden npz + baton_links: tuple # tuple of BatonLink (Cell C aliases these intermediate BOs) + + +def validate_baton_links(sub_launches, baton_links): + """Sanity check: each link's consumer must come after its producer in the sequence.""" + for link in baton_links: + if link.consumer_idx <= link.producer_idx: + raise ValueError( + f"baton link consumer_idx={link.consumer_idx} must be greater than " + f"producer_idx={link.producer_idx}" + ) + if link.producer_idx >= len(sub_launches): + raise ValueError(f"producer_idx {link.producer_idx} out of range") + if link.consumer_idx >= len(sub_launches): + raise ValueError(f"consumer_idx {link.consumer_idx} out of range") diff --git a/programming_examples/llama32_1b/ablation/prefill/specs/o_ffn.py b/programming_examples/llama32_1b/ablation/prefill/specs/o_ffn.py new file mode 100644 index 000000000..0fa08a12f --- /dev/null +++ b/programming_examples/llama32_1b/ablation/prefill/specs/o_ffn.py @@ -0,0 +1,322 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + +"""Concrete KernelGroupSpec for the prefill o_ffn kernel-group. + +Mirrors the production stitch-spec in multi_launch_builder/o_ffn_multi.py. +8 sequential launches at seq=2048, emb_dim=2048, hidden_dim=8192: + + L1 o_gemm [8,4] attn_out x wo -> proj + L2 res_add [8,1] proj + x_residual -> res1 (2D out) + L3 ffn_rmsnorm [8,1] res1 x ffn_norm_w -> normed2 + L4 gate_gemm [8,4] normed2 x w_gate -> gate + L5 up_gemm [8,4] normed2 x w_up -> up + L6 swiglu [8,1] SiLU(gate) x up -> swiglu + L7 down_gemm [8,4] swiglu x w_down -> down + L8 ffn_add [8,1] down + res1 -> output (1D out) + +15 merged-func args (slots 0-14); static slots {1,5,7,9,12}; +intermediate slots {2,4,6,8,10,11,13,14}. + +Slot conventions per sub-launch standalone signatures: + - gemm: (A[seq,K], B[K,N], C[seq,N]) weight=1, out=2 + - add_2d_to_2d: (A[seq,d], B[seq,d], C[seq,d]) no weight, out=2 + - rmsnorm: (x[seq,d], w[d], out[seq,d]) weight=1, out=2 + - swiglu_2d: (gate[seq,h], up[seq,h], out[seq,h]) no weight, out=2 + - ffn_add: (A[seq,d], B[seq,d], out[n_total]) no weight, out=2 +""" + +from ml_dtypes import bfloat16 + +from specs.kernel_group import SubLaunchSpec, BatonLink, KernelGroupSpec + +# --------------------------------------------------------------------------- +# Sub-launch standalone builders +# --------------------------------------------------------------------------- + + +def _build_o_gemm_standalone(): + """O projection GEMM: attn_out(2048,2048) x wo(2048,2048) -> proj(2048,2048).""" + from kernel_builder.gemm_builder import _build_gemm_module + + return _build_gemm_module( + 2048, + 2048, + 2048, + tile_m=64, + tile_k_l2=256, + tile_k_l1=32, + tile_n=64, + herd_m=8, + herd_n=4, + ) + + +def _build_res_add_standalone(): + """Residual add (2D→2D): proj + x_residual -> res1.""" + from multi_launch_builder.o_ffn_multi import _build_add_2d_to_2d + + return _build_add_2d_to_2d(2048, 2048, bfloat16) + + +def _build_rmsnorm_standalone(): + """FFN RMSNorm (bare herd → wrap in air.launch).""" + from weighted_rms_norm.weighted_rms_norm import build_module as build_rms + from kernel_builder.stitching import _wrap_ir_in_launch + from air.ir import Module + + bare = str(build_rms(2048, 2048, bfloat16, 16, herd_x=8)) + return Module.parse(_wrap_ir_in_launch(bare)) + + +def _build_gateup_gemm_standalone(n): + """Gate or Up GEMM: normed2(2048,2048) x w(2048,n) -> out(2048,n).""" + from kernel_builder.gemm_builder import _build_gemm_module + + return _build_gemm_module( + 2048, + 2048, + n, + tile_m=64, + tile_k_l2=64, + tile_k_l1=32, + tile_n=128, + herd_m=8, + herd_n=4, + ) + + +def _build_swiglu_standalone(): + """SwiGLU activation: SiLU(gate) * up -> swiglu (2D memref variant). + + Uses build_module_2d from kernel_builder/ffn_swiglu/silu_and_mul.py. + Signature: (rows, cols, tile_n, np_dtype_in, herd_x=8, herd_y=1). + Already wraps in air.launch — no _wrap_ir_in_launch needed. + Arg slots in standalone: 0=gate, 1=up, 2=out. + """ + from kernel_builder.ffn_swiglu.silu_and_mul import build_module_2d as build_swiglu + + return build_swiglu(2048, 8192, 4096, bfloat16, herd_x=8, herd_y=1) + + +def _build_down_gemm_standalone(): + """Down GEMM: swiglu(2048,8192) x w_down(8192,2048) -> down(2048,2048).""" + from kernel_builder.gemm_builder import _build_gemm_module + + return _build_gemm_module( + 2048, + 8192, + 2048, + tile_m=64, + tile_k_l2=256, + tile_k_l1=32, + tile_n=64, + herd_m=8, + herd_n=4, + ) + + +def _build_ffn_add_standalone(): + """FFN Add (2D inputs → 1D output): down + res1 -> output[n_total]. + + Replicated from the nested _build_add_2d_to_1d() in o_ffn_multi.py + (that function is defined inline inside build_o_ffn_module and cannot + be imported directly). + + Arg slots: 0=A (down, 2D), 1=B (res1, 2D), 2=out (1D). + """ + from air.ir import ( + AffineConstantExpr, + AffineExpr, + AffineMap, + AffineMapAttr, + AffineSymbolExpr, + IntegerAttr, + IntegerType, + MemRefType, + VectorType, + UnitAttr, + StringAttr, + ) + from air.dialects.affine import apply as affine_apply + from air.dialects.air import launch, segment, herd, module_builder + from air.dialects.memref import ( + collapse_shape as memref_collapse_shape, + AllocOp, + DeallocOp, + subview, + ) + from air.dialects.func import FuncOp + from air.dialects.scf import for_, yield_ + from air.dialects import arith + from air.dialects.vector import transfer_read, transfer_write + from air.backend.xrt_runner import type_mapper + from air.dialects.air import MemorySpace + + seq_len = 2048 + emb_dim = 2048 + n_total = seq_len * emb_dim + total_tiles = 8 + chunk_size = n_total // total_tiles + tile_n = emb_dim + + @module_builder + def _build(): + xrt_dtype = type_mapper(bfloat16) + l3_2d_ty = MemRefType.get([seq_len, emb_dim], xrt_dtype) + l3_1d_ty = MemRefType.get([n_total], xrt_dtype) + l1_space = IntegerAttr.get(IntegerType.get_signless(32), MemorySpace.L1) + l1_ty = MemRefType.get([tile_n], xrt_dtype, memory_space=l1_space) + vec_ty = VectorType.get([16], xrt_dtype) + identity_map = AffineMapAttr.get(AffineMap.get_identity(1)) + + @FuncOp.from_py_func(l3_2d_ty, l3_2d_ty, l3_1d_ty) + def eltwise_add(a_2d, b_2d, out_1d): + @launch(operands=[a_2d, b_2d, out_1d]) + def add_launch(l_a, l_b, l_out): + a_flat = memref_collapse_shape(l3_1d_ty, l_a, [[0, 1]]) + b_flat = memref_collapse_shape(l3_1d_ty, l_b, [[0, 1]]) + + @segment(name="add_seg", operands=[a_flat, b_flat, l_out]) + def add_seg(s_a, s_b, s_out): + offset_map = AffineMap.get( + 0, + 3, + [ + AffineExpr.get_add( + AffineSymbolExpr.get(0), + AffineExpr.get_mul( + AffineExpr.get_add( + AffineExpr.get_mul( + AffineSymbolExpr.get(1), + AffineConstantExpr.get(1), + ), + AffineSymbolExpr.get(2), + ), + AffineConstantExpr.get(chunk_size), + ), + ) + ], + ) + + @herd( + name="add_herd", + sizes=[8, 1], + operands=[s_a, s_b, s_out], + ) + def add_body(_tx, _ty, _sx, _sy, h_a, h_b, h_out): + l1_a = AllocOp(l1_ty, [], []) + l1_b = AllocOp(l1_ty, [], []) + l1_out = AllocOp(l1_ty, [], []) + c0 = arith.ConstantOp.create_index(0) + cst0 = arith.ConstantOp(xrt_dtype, 0.0) + for loop_iv in for_(0, chunk_size, tile_n): + offset = affine_apply(offset_map, [loop_iv, _tx, _ty]) + from air.dialects.air import dma_memcpy_nd + + dma_memcpy_nd( + l1_a, + h_a, + src_offsets=[offset], + src_sizes=[tile_n], + src_strides=[1], + ) + dma_memcpy_nd( + l1_b, + h_b, + src_offsets=[offset], + src_sizes=[tile_n], + src_strides=[1], + ) + for j in for_(0, tile_n, 16): + sub_a = subview(l1_a.result, [j], [16], [1]) + sub_b = subview(l1_b.result, [j], [16], [1]) + sub_out = subview(l1_out.result, [j], [16], [1]) + v_a = transfer_read( + vec_ty, sub_a, [c0], identity_map, cst0, [True] + ) + v_b = transfer_read( + vec_ty, sub_b, [c0], identity_map, cst0, [True] + ) + v_sum = arith.addf(v_a, v_b) + transfer_write( + None, v_sum, sub_out, [c0], identity_map, [True] + ) + yield_([]) + dma_memcpy_nd( + h_out, + l1_out, + dst_offsets=[offset], + dst_sizes=[tile_n], + dst_strides=[1], + ) + yield_([]) + DeallocOp(l1_a) + DeallocOp(l1_b) + DeallocOp(l1_out) + + return _build() + + +# --------------------------------------------------------------------------- +# KernelGroupSpec +# --------------------------------------------------------------------------- + +SPEC = KernelGroupSpec( + name="o_ffn", + sub_launches=( + # idx=0: O GEMM — weight at slot 1 (wo), output at slot 2 (proj) + SubLaunchSpec("o_gemm", _build_o_gemm_standalone, {}, 1, 2), + # idx=1: Res Add — no weight, output at slot 2 (res1[2D]) + SubLaunchSpec("res_add", _build_res_add_standalone, {}, None, 2), + # idx=2: FFN RMSNorm — weight at slot 1 (ffn_norm_w), output at slot 2 (normed2) + SubLaunchSpec("ffn_rmsnorm", _build_rmsnorm_standalone, {}, 1, 2), + # idx=3: Gate GEMM — weight at slot 1 (w_gate), output at slot 2 (gate) + SubLaunchSpec("gate_gemm", _build_gateup_gemm_standalone, {"n": 8192}, 1, 2), + # idx=4: Up GEMM — weight at slot 1 (w_up), output at slot 2 (up) + SubLaunchSpec("up_gemm", _build_gateup_gemm_standalone, {"n": 8192}, 1, 2), + # idx=5: SwiGLU — no weight, gate=slot0, up=slot1, output at slot 2 + SubLaunchSpec("swiglu", _build_swiglu_standalone, {}, None, 2), + # idx=6: Down GEMM — weight at slot 1 (w_down), output at slot 2 (down) + SubLaunchSpec("down_gemm", _build_down_gemm_standalone, {}, 1, 2), + # idx=7: FFN Add — no weight, A=slot0 (down), B=slot1 (res1), output at slot 2 + SubLaunchSpec("ffn_add", _build_ffn_add_standalone, {}, None, 2), + ), + merged_arg_signature=( + "attn_out", # 0 activation input + "wo", # 1 weight (static) + "proj", # 2 intermediate + "x_residual", # 3 activation input + "res1", # 4 intermediate (shared: res_add out + ffn_add B) + "ffn_norm_w", # 5 weight (static) + "normed2", # 6 intermediate + "w_gate", # 7 weight (static) + "gate", # 8 intermediate + "w_up", # 9 weight (static) + "up", # 10 intermediate + "swiglu", # 11 intermediate + "w_down", # 12 weight (static) + "down", # 13 intermediate + "output", # 14 intermediate (final 1D output) + ), + weight_slots=frozenset({1, 5, 7, 9, 12}), + intermediate_slots=frozenset({2, 4, 6, 8, 10, 11, 13, 14}), + output_slots_for_validation=(14,), + baton_links=( + # Stitch arg_map verified against o_ffn_multi.py lines 457-465: + # L1 {0:0,1:1,2:2} L2 {0:2,1:3,2:4} L3 {0:4,1:5,2:6} + # L4 {0:6,1:7,2:8} L5 {0:6,1:9,2:10} L6 {0:8,1:10,2:11} + # L7 {0:11,1:12,2:13} L8 {0:13,1:4,2:14} + BatonLink(0, 2, 1, 0), # o_gemm.proj (slot2) -> res_add.A (slot0) + BatonLink(1, 2, 2, 0), # res_add.res1 (slot2) -> ffn_rmsnorm.x (slot0) + BatonLink(2, 2, 3, 0), # ffn_rmsnorm.normed2 (slot2) -> gate_gemm.x (slot0) + BatonLink(2, 2, 4, 0), # ffn_rmsnorm.normed2 (slot2) -> up_gemm.x (slot0) + BatonLink(3, 2, 5, 0), # gate_gemm.gate (slot2) -> swiglu.gate (slot0) + BatonLink(4, 2, 5, 1), # up_gemm.up (slot2) -> swiglu.up (slot1) + BatonLink(5, 2, 6, 0), # swiglu.swiglu (slot2) -> down_gemm.x (slot0) + BatonLink(6, 2, 7, 0), # down_gemm.down (slot2) -> ffn_add.A (slot0) + BatonLink( + 1, 2, 7, 1 + ), # res_add.res1 (slot2) -> ffn_add.B (slot1) [residual-of-residual] + ), +) diff --git a/programming_examples/llama32_1b/ablation/prefill/specs/rms_gemms_rope.py b/programming_examples/llama32_1b/ablation/prefill/specs/rms_gemms_rope.py new file mode 100644 index 000000000..70d991c97 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/prefill/specs/rms_gemms_rope.py @@ -0,0 +1,130 @@ +"""Concrete KernelGroupSpec for the prefill rms_gemms_rope kernel-group. + +Mirrors the production stitch-spec in +multi_launch_builder/rms_gemms_rope_multi.py:467-474 (which lists the +arg mappings for the 6 sub-launches in the merged ELF). + +Slot conventions for standalones: + - rmsnorm: (x_in[seq, emb], norm_w[emb], out[seq, emb]) output at slot 2 + - gemm: (a[seq, K], b[K, N], c[seq, N]) output at slot 2 + (kernel_builder/gemm_builder.py:107 signature is (m, k, n, ...) — + no positional M arg; weight at slot 1, output at slot 2.) + - rope_2d: (in_2d[rows, cols], lut_1d[N], out_2d[rows, cols]) output at slot 2 +""" + +from ml_dtypes import bfloat16 + +from specs.kernel_group import SubLaunchSpec, BatonLink, KernelGroupSpec + + +def _build_rmsnorm_standalone(): + """Wrap weighted_rms_norm in air.launch+segment for solo invocation.""" + from weighted_rms_norm.weighted_rms_norm import build_module as build_rms + from kernel_builder.stitching import _wrap_ir_in_launch + from air.ir import Module + + bare = str(build_rms(2048, 2048, bfloat16, 16, herd_x=8)) + wrapped_text = _wrap_ir_in_launch(bare) + return Module.parse(wrapped_text) + + +def _build_gemm_standalone(k, n): + """Production prefill GEMM: (seq=2048, k, n) with the production tile config. + + _build_gemm_module signature: (m, k, n, tile_m, tile_k_l2, tile_k_l1, tile_n, + herd_m, herd_n). Slots in standalone: 0=A (activation), 1=B (weight), 2=C (output). + """ + from kernel_builder.gemm_builder import _build_gemm_module + + return _build_gemm_module( + 2048, + k, + n, + tile_m=64, + tile_k_l2=64, + tile_k_l1=32, + tile_n=128, + herd_m=8, + herd_n=4, + ) + + +def _build_rope_2d_standalone(outer_rows, outer_cols): + from multi_launch_builder.rms_gemms_rope_multi import _build_rope_2d + + return _build_rope_2d(outer_rows, outer_cols, 64, bfloat16, herd_x=8) + + +SPEC = KernelGroupSpec( + name="rms_gemms_rope", + sub_launches=( + SubLaunchSpec("rmsnorm", _build_rmsnorm_standalone, {}, 1, 2), + SubLaunchSpec("q_gemm", _build_gemm_standalone, {"k": 2048, "n": 2048}, 1, 2), + SubLaunchSpec("k_gemm", _build_gemm_standalone, {"k": 2048, "n": 512}, 1, 2), + SubLaunchSpec("v_gemm", _build_gemm_standalone, {"k": 2048, "n": 512}, 1, 2), + SubLaunchSpec( + "rope_q", + _build_rope_2d_standalone, + {"outer_rows": 2048, "outer_cols": 2048}, + 1, + 2, + ), + SubLaunchSpec( + "rope_k", + _build_rope_2d_standalone, + {"outer_rows": 2048, "outer_cols": 512}, + 1, + 2, + ), + ), + merged_arg_signature=( + "x_in", + "norm_w", + "normed", + "wq", + "q", + "wk", + "k", + "wv", + "v", + "lut_q", + "lut_k", + "q_roped", + "k_roped", + ), + weight_slots=frozenset({1, 3, 5, 7, 9, 10}), + intermediate_slots=frozenset({2, 4, 6, 8, 11, 12}), + output_slots_for_validation=(2, 4, 6, 8, 11, 12), + baton_links=( + BatonLink( + producer_idx=0, + producer_out_slot=2, + consumer_idx=1, + consumer_in_slot=0, + ), # rmsnorm.normed -> q_gemm.x + BatonLink( + producer_idx=0, + producer_out_slot=2, + consumer_idx=2, + consumer_in_slot=0, + ), # rmsnorm.normed -> k_gemm.x + BatonLink( + producer_idx=0, + producer_out_slot=2, + consumer_idx=3, + consumer_in_slot=0, + ), # rmsnorm.normed -> v_gemm.x + BatonLink( + producer_idx=1, + producer_out_slot=2, + consumer_idx=4, + consumer_in_slot=0, + ), # q_gemm.q -> rope_q.in + BatonLink( + producer_idx=2, + producer_out_slot=2, + consumer_idx=5, + consumer_in_slot=0, + ), # k_gemm.k -> rope_k.in + ), +) diff --git a/programming_examples/llama32_1b/ablation/prefill/standalone_builders/__init__.py b/programming_examples/llama32_1b/ablation/prefill/standalone_builders/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/programming_examples/llama32_1b/ablation/prefill/standalone_builders/o_ffn.py b/programming_examples/llama32_1b/ablation/prefill/standalone_builders/o_ffn.py new file mode 100644 index 000000000..4df578e17 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/prefill/standalone_builders/o_ffn.py @@ -0,0 +1,10 @@ +"""Single-launch standalone modules for the prefill o_ffn kernel-group. + +Exports a STANDALONES registry compatible with cells/common.py:compile_standalone_kernels. +""" + +from specs.o_ffn import SPEC + +STANDALONES = [ + (sub.name, sub.builder_ref, sub.build_kwargs) for sub in SPEC.sub_launches +] diff --git a/programming_examples/llama32_1b/ablation/prefill/standalone_builders/rms_gemms_rope.py b/programming_examples/llama32_1b/ablation/prefill/standalone_builders/rms_gemms_rope.py new file mode 100644 index 000000000..8b83e111c --- /dev/null +++ b/programming_examples/llama32_1b/ablation/prefill/standalone_builders/rms_gemms_rope.py @@ -0,0 +1,11 @@ +"""Single-launch standalone modules for the prefill rms_gemms_rope kernel-group. + +Exports a STANDALONES registry compatible with cells/common.py:compile_standalone_kernels. +Each entry: (name, build_fn, build_kwargs). +""" + +from specs.rms_gemms_rope import SPEC + +STANDALONES = [ + (sub.name, sub.builder_ref, sub.build_kwargs) for sub in SPEC.sub_launches +] diff --git a/programming_examples/llama32_1b/ablation/prefill/tests/__init__.py b/programming_examples/llama32_1b/ablation/prefill/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/programming_examples/llama32_1b/ablation/prefill/tests/conftest.py b/programming_examples/llama32_1b/ablation/prefill/tests/conftest.py new file mode 100644 index 000000000..484728c8c --- /dev/null +++ b/programming_examples/llama32_1b/ablation/prefill/tests/conftest.py @@ -0,0 +1,28 @@ +"""Pytest config for prefill ablation tests. + +Inserts paths so tests can import: +- llama32_1b/ packages (kernel_builder, multi_launch_builder) +- llama32_1b/ablation/ (Plan 1's validate.py and shared helpers) +- llama32_1b/ablation/prefill/ (this package) +- programming_examples/ (matvec, weighted_rms_norm, ffn_swiglu) +""" + +import os +import sys + +_THIS = os.path.dirname(os.path.abspath(__file__)) +_PREFILL = os.path.dirname(_THIS) +_ABLATION = os.path.dirname(_PREFILL) +_LLAMA = os.path.dirname(_ABLATION) +_PROG_EXAMPLES = os.path.dirname(_LLAMA) + +for p in (_PROG_EXAMPLES, _LLAMA, _ABLATION, _PREFILL): + if p not in sys.path: + sys.path.insert(0, p) + +# Pytest's package-import mode inserts the package parent (ablation/) into sys.path[0] +# before this conftest runs, which can shadow prefill/validate.py with ablation/validate.py. +# Guarantee that prefill/ is at index 0 so prefill-local modules take priority. +if sys.path[0] != _PREFILL: + sys.path.remove(_PREFILL) if _PREFILL in sys.path else None + sys.path.insert(0, _PREFILL) diff --git a/programming_examples/llama32_1b/ablation/prefill/tests/test_kernel_group_spec.py b/programming_examples/llama32_1b/ablation/prefill/tests/test_kernel_group_spec.py new file mode 100644 index 000000000..8fd92f0d9 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/prefill/tests/test_kernel_group_spec.py @@ -0,0 +1,56 @@ +"""Unit tests for the KernelGroupSpec dataclasses.""" + +import pytest +from specs.kernel_group import SubLaunchSpec, BatonLink, KernelGroupSpec + + +def _dummy_builder(): + return None # Spec test doesn't need a real builder + + +def test_sublaunch_spec_is_frozen(): + s = SubLaunchSpec( + name="rms", + builder_ref=_dummy_builder, + build_kwargs={"emb_dim": 2048}, + weight_slot_in_standalone=1, + output_slot_in_standalone=2, + ) + with pytest.raises((AttributeError, TypeError)): # frozen + s.name = "other" + + +def test_baton_link_orders_by_indices(): + link = BatonLink( + producer_idx=0, producer_out_slot=2, consumer_idx=1, consumer_in_slot=1 + ) + assert link.consumer_idx > link.producer_idx + + +def test_kernel_group_spec_holds_sublaunches(): + sub = SubLaunchSpec("rms", _dummy_builder, {}, 1, 2) + spec = KernelGroupSpec( + name="rms_gemms_rope", + sub_launches=(sub,), # tuple — frozen dataclass + merged_arg_signature=("x_in", "norm_w", "normed"), + weight_slots=frozenset({1}), + intermediate_slots=frozenset({2}), + output_slots_for_validation=(2,), + baton_links=(), + ) + assert spec.name == "rms_gemms_rope" + assert len(spec.sub_launches) == 1 + + +def test_baton_link_consumer_must_follow_producer(): + """A baton link with consumer_idx <= producer_idx is meaningless; + spec dataclass tolerates it but a validator rejects.""" + from specs.kernel_group import validate_baton_links + + sub_a = SubLaunchSpec("a", _dummy_builder, {}, 1, 2) + sub_b = SubLaunchSpec("b", _dummy_builder, {}, 1, 2) + bad = BatonLink( + producer_idx=1, producer_out_slot=2, consumer_idx=0, consumer_in_slot=1 + ) + with pytest.raises(ValueError, match="consumer_idx"): + validate_baton_links([sub_a, sub_b], [bad]) diff --git a/programming_examples/llama32_1b/ablation/prefill/tests/test_validation_gate.py b/programming_examples/llama32_1b/ablation/prefill/tests/test_validation_gate.py new file mode 100644 index 000000000..3589bcc43 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/prefill/tests/test_validation_gate.py @@ -0,0 +1,48 @@ +"""Test the prefill validation gate against the committed goldens.""" + +import os + +import numpy as np +import pytest +from ml_dtypes import bfloat16 + +from validate import validate_against_golden, GoldenMismatch + +GOLDEN_DIR = os.path.join(os.path.dirname(__file__), "..", "golden") + + +def _load(filename): + npz = np.load(os.path.join(GOLDEN_DIR, filename)) + return {k: npz[k] for k in npz.files} + + +def test_rms_gemms_rope_passes_on_exact_match(): + g = _load("golden_rms_gemms_rope_prefill.npz") + validate_against_golden(g, GOLDEN_DIR, "golden_rms_gemms_rope_prefill.npz") + + +def test_rms_gemms_rope_raises_on_byte_diff(): + g = _load("golden_rms_gemms_rope_prefill.npz") + perturbed = {k: v.copy() for k, v in g.items()} + arr = perturbed["normed"].view(np.uint8).copy() + arr[0] ^= 0x01 + perturbed["normed"] = arr.view(bfloat16).reshape(g["normed"].shape) + with pytest.raises(GoldenMismatch, match="normed"): + validate_against_golden( + perturbed, GOLDEN_DIR, "golden_rms_gemms_rope_prefill.npz" + ) + + +def test_o_ffn_passes_on_exact_match(): + g = _load("golden_o_ffn_prefill.npz") + validate_against_golden(g, GOLDEN_DIR, "golden_o_ffn_prefill.npz") + + +def test_o_ffn_raises_on_byte_diff(): + g = _load("golden_o_ffn_prefill.npz") + perturbed = {k: v.copy() for k, v in g.items()} + arr = perturbed["output"].view(np.uint8).copy() + arr[0] ^= 0x01 + perturbed["output"] = arr.view(bfloat16).reshape(g["output"].shape) + with pytest.raises(GoldenMismatch, match="output"): + validate_against_golden(perturbed, GOLDEN_DIR, "golden_o_ffn_prefill.npz") diff --git a/programming_examples/llama32_1b/ablation/prefill/validate.py b/programming_examples/llama32_1b/ablation/prefill/validate.py new file mode 100644 index 000000000..e5ae14539 --- /dev/null +++ b/programming_examples/llama32_1b/ablation/prefill/validate.py @@ -0,0 +1,49 @@ +"""Per-cell validation — parameterized version of Plan 1's validate.py. + +Plan 1's validate.py hardcodes the golden filename to +"golden_rms_gemv_rope.npz". For prefill we have two goldens, so we +parameterize the filename. The byte-equality contract is identical. +""" + +import os + +import numpy as np + + +class GoldenMismatch(AssertionError): + """Raised when a cell's output diverges from the committed golden.""" + + +def validate_against_golden(cell_outputs: dict, golden_dir: str, npz_filename: str): + """Compare every key in cell_outputs to the matching array in + /. Raise GoldenMismatch on any diff.""" + npz = np.load(os.path.join(golden_dir, npz_filename)) + for key in npz.files: + if key not in cell_outputs: + raise GoldenMismatch(f"cell missing output '{key}'") + gv = npz[key] + cv = cell_outputs[key] + if cv.shape != gv.shape: + raise GoldenMismatch( + f"{key}: shape mismatch cell={cv.shape} golden={gv.shape}" + ) + if cv.dtype.itemsize != gv.dtype.itemsize: + raise GoldenMismatch(f"{key}: itemsize mismatch") + if cv.tobytes() != gv.tobytes(): + from ml_dtypes import bfloat16 as _bf16 + + cf = ( + cv.view(np.uint8).view(_bf16).astype(np.float32) + if cv.dtype != np.float32 + else cv + ) + gf = ( + gv.view(np.uint8).view(_bf16).astype(np.float32) + if gv.dtype != np.float32 + else gv + ) + max_abs = float(np.max(np.abs(cf - gf))) + max_rel = float(np.max(np.abs((cf - gf) / (np.abs(gf) + 1e-9)))) + raise GoldenMismatch( + f"{key}: byte mismatch max_abs={max_abs:.4g} max_rel={max_rel:.4g}" + ) diff --git a/programming_examples/llama32_1b/docs/ABLATION_STUDY.html b/programming_examples/llama32_1b/docs/ABLATION_STUDY.html new file mode 100644 index 000000000..520a0f6df --- /dev/null +++ b/programming_examples/llama32_1b/docs/ABLATION_STUDY.html @@ -0,0 +1,830 @@ + + + + +Llama-3.2-1B on AMD NPU2 — Ablation Study + + + + + + + + + +

Llama-3.2-1B on AMD NPU2 — Ablation Study

+

Quantifying which dispatch optimizations contribute how much to the production runtime. Companion to IMPLEMENTATION_GUIDE.html Part B3-B7 (the four gaps).

+ +
+ What this document is. A walkthrough of the 4-cell ablation study covering what we measured and what the cells differ on. Two studies: decode (full per-token) and prefill (full 16-layer). If you just want the punchline numbers, jump to decode results, prefill results, or cross-comparison. +
+ + +

Background — what's running on the NPU

+ +

Llama-3.2-1B is a 1.24 B-parameter decoder-only transformer (16 layers, emb=2048, n_heads=32, n_kv_heads=8, head_dim=64, hidden=8192, BF16). On AMD NPU2 it runs as 6 production ELFs orchestrated from a Python host. Each ELF is one or more air.launchs stitched into a single xrt.run(). Per pass:

+ + + + + + + + + + Prefill — per layer + 3 NPU calls per layer × 16 layers + 1 LM head = 49 NPU calls/pass + + + + rms_gemms_rope.elf + 6 stitched launches + RMSNorm + Q/K/V GEMM + 2× RoPE + + + + + flash_attn.elf + 1 launch (un-mergeable) + FA causal GQA + + + + + o_ffn.elf + 8 stitched launches + O GEMM + Add + RMSNorm + SwiGLU + Add + + → loop L+1 + + + + Decode — per token, per layer + 2 NPU calls + 1 CPU step per layer × 16 + 1 LM head = 33 NPU calls/token + + + + rms_gemv_rope.elf + 6 stitched launches + RMSNorm + Q/K/V GEMV + 2× RoPE + + + + + decode_attn (CPU) + single-query GQA + + KV cache append + + + + + o_gemv_ffn.elf + 8 stitched launches + O GEMV + Add + RMSNorm + SwiGLU + Add + + → loop L+1 + + + + + after 16 layers + + lm_head_gemv.elf — shared by both phases + 8-partition GEMV in 1 xrt.run() → argmax + + +

Three observations matter for the ablation that follows:

+
    +
  • Production already uses multi-launch ELF stitching. Each box above hides 1, 6, or 8 sub-launches but appears to the host as a single xrt.run(). The naive baseline (Cell A) instead launches every sub-kernel as its own xrt.run() — so a naive prefill issues 240 dispatches per pass instead of 48, and a naive decode issues ~96 dispatches per token instead of 33.
  • +
  • FlashAttention sits between two stitchable groups. FA is its own ELF (un-mergeable into the surrounding rms_gemms_rope or o_ffn — see IMPLEMENTATION_GUIDE B5). So even Cell D has 3 dispatches per prefill layer, not 1.
  • +
  • Decode uses CPU attention. Per-token attention has small enough work to be cheaper on CPU than on the NPU FA path at head_dim=64. So decode's per-layer dispatch is 2 NPU calls + 1 CPU step.
  • +
+ +

The ablation runs three studies on this dispatch picture:

+ + + + + + + + + + + + + +
StudyScopeHeadline result
Decode (Part 3)Both decode kernel-groups + CPU attention + LM head, full per-token loop (the full decode row above)Cell D = 90.65 ms/token; A→D = 2.83×
Prefill (Part 4)Both prefill kernel-groups + FA, 16 layers (the full prefill row above)Cell D = 1.13 s ≈ profile.md's 1.27 s; A→D = 1.56×
+ + + +

Part 1 — High level: what are we measuring?

+ +

1.1 The question

+ +

The production runtime achieves 1.27 s prefill (per profile.md) and a per-token decode latency much smaller than a naive implementation. The IMPLEMENTATION_GUIDE B3 argues that this comes from solving four "gaps" between standalone kernels and end-to-end inference:

+ + + + + + + +
GapSection in IMPLEMENTATION_GUIDE
#1 — XRT dispatch overhead (multi-launch ELF stitching)B5
#2 — Per-call BO management, weights pre-loaded once (per-layer weight BOs)B6 + B7
#3 — Intermediate buffers shared across separate xrt.run() calls (only relevant in the un-merged baseline)B6
#4 — KernelCache compile-once + per-process caching (not in this ablation; held constant)B7
+ +

"We built X, Y, Z, and inference got faster" doesn't tell us how much each individual change matters. The ablation builds a 4-cell ladder that adds the optimizations one at a time on top of a naive baseline, so each cell isolates the marginal contribution of a single optimization.

+ + +

1.2 The 4-cell ladder (A → B → C → D)

+ +

Each cell runs the SAME computation on the SAME input. Only the dispatch strategy changes. The cells are cumulative: each one keeps the previous cell's optimizations and adds one more.

+ +
+
+
Cell A Naive baseline
+

One xrt.run() per sub-kernel. Every call writes every input slot to device, runs the kernel, reads every output back. KernelCache invoked with naive=True so the index-set optimizations are disabled.

+

Adds: nothing (baseline)

+
+
+
Cell B + per-layer weight BOs (gap #2)
+

Same N xrt.run()s as A, but weights pre-loaded once into per-layer BOs. static_input_indices tells KernelCache to skip the host write for those slots on every call.

+

Adds: gap #2 alone

+
+
+
Cell C + shared intermediate BOs (gap #3)
+

Still N separate xrt.run()s, but each producer's output BO is aliased to the next consumer's input BO via _share_bo. So the host doesn't transport intermediates between calls — they stay in the same DDR region.

+

Adds: gap #3 alone

+
+
+
Cell D + multi-launch merging (gap #1) = production
+

One merged ELF containing all N air.launchs. ONE xrt.run() drives them all. Intermediates flow through DDR via NPU DMA, never through the host. This is exactly what production uses.

+

Adds: gap #1 alone

+
+
+ +

Reading the deltas:

+
    +
  • A → B = isolated effect of gap #2 (per-layer weight BOs)
  • +
  • B → C = isolated effect of gap #3 (shared intermediate BOs)
  • +
  • C → D = isolated effect of gap #1 (multi-launch merging) — the "pure merging" delta
  • +
  • A → D = total speedup of all three together
  • +
+ +
+ Why this ordering matters for fair attribution. Gap #1 (merging) and gap #3 (BO sharing) are alternative ways to keep intermediates on device. If you measured C→D in isolation (without first applying B and C), you might conflate the two. The ladder ordering A→B→C→D ensures gap #1's marginal effect (C→D) measures ONLY the host-orchestration savings beyond what BO-sharing already provides — i.e., the cost of N kernel dispatches vs. 1. +
+ + +

1.3 Two studies — decode and prefill

+ +

The 4-cell ladder is applied at two scopes — one per inference phase:

+ + + + + + + +
Decode (Part 3)Prefill (Part 4)
ScopeBOTH decode kernel-groups (rms_gemv_rope, o_gemv_ffn) + CPU attention + LM head + per-token loop × 16 layersBOTH prefill kernel-groups (rms_gemms_rope, o_ffn) + FlashAttention + 16-layer wrapper
Per-cell wall time~90-260 ms per token~1.1-1.8 s per pass
Cell D matches…profile.md's per-token decode latencyprofile.md's 1.27 s prefill headline
Why bothDecode is dispatch-overhead-bound (per-call NPU work is small)Prefill has large per-call NPU work; the SAME optimizations may behave differently — we want to find out, not assume
+ + + +

Part 2 — Methodology

+ +

2.1 The unit of measurement

+ +

Each plan measures something different:

+ + + + + + + + + + + + + + + +
StudyWhat's timedWhere the timer wrapsWhat's NOT in the number
Decode (Part 3)One full per-token loop: 16 layers × (rms_gemv_rope + CPU attention + o_gemv_ffn) + final RMSNorm + lm_head_gemv + argmaxt_total_start in cells/per_token_loop.py immediately before layer 0; elapsed at the end of argmaxCompile time, BO allocation (counted as preload), KV-cache initialization (counted as preload)
Prefill (Part 4)16 layers of dispatch: per layer, rms_gemms_rope + FA + o_ffn. Includes host-side data threading between launchest_total_start in multi_layer.py:run_16_layer_prefill immediately before the first layer; elapsed at the end of the last layer's o_ffnEmbedding lookup, final RMSNorm + LM Head GEMV, KV-cache extraction transposes (~150 ms residual; accounts for the gap between Cell D's 1.13 s and profile.md's 1.27 s)
+ +

Concrete: what is one "Cell D timing"? Decode Cell D's median is 90.65 ms — the wall time of one full per-token loop with all production optimizations enabled (each layer's two NPU calls merged, weights resident in per-layer BOs). Cell A's median is 256.69 ms — the same loop but every sub-launch as its own xrt.run() with full host I/O, and weights re-uploaded each call. Same total computation, different dispatch strategy.

+ + +

2.2 Bit-exact validation gate — guarantees same computation

+ +

Every cell must produce byte-identical outputs to a committed golden fixture before its timing is reported. A cell that "ran faster" by accidentally running a different (wrong) computation is suppressed before it can show up in the report.

+ +

The mechanism

+ +

Each plan has a golden/ directory holding .npz files written by Cell D on a fixed deterministic input. Before timing begins, every cell runs once with the same input and the output bytes are compared to the golden:

+ +
# validate.py — the gate
+def validate_against_golden(cell_outputs: dict, golden_dir: str):
+    npz = np.load(os.path.join(golden_dir, "golden_rms_gemv_rope.npz"))
+    for key in npz.files:
+        gv = npz[key]
+        cv = cell_outputs[key]
+        if cv.shape != gv.shape:
+            raise GoldenMismatch(f"{key}: shape mismatch ...")
+        if cv.dtype.itemsize != gv.dtype.itemsize:
+            raise GoldenMismatch(f"{key}: dtype size mismatch ...")
+        if cv.tobytes() != gv.tobytes():           # EXACT byte equality, no tolerance
+            raise GoldenMismatch(f"{key}: byte mismatch")
+ +

Cells that fail the gate have their timing suppressed in the report, so a numerically-different "fast" cell can't sneak its way into the headline.

+ +

Why bit-exact, not tolerance-based?

+ +

BF16 numerics already have ~3-4 decimal digits of variability vs. F32. If we used a numerical tolerance like "max relative error < 1e-3", a cell could silently introduce a computation difference that changes BF16 outputs in the 4th-5th significant digit, fall under the tolerance threshold, and be falsely accepted. The 4 cells should be doing IDENTICAL computation — only dispatch differs. The kernel binaries are even the same when applicable. So the outputs should be byte-identical, and a deviation is a methodology bug to investigate (not a numerical artifact to tolerate).

+ +

Empirically all 4 cells DO produce bit-identical outputs for both decode and prefill. This is independent confirmation that the dispatch differences are purely orchestration changes — none of them re-tile or re-vectorize the kernels in any way.

+ + +

2.3 Synthetic deterministic inputs

+ +

Inputs come from golden/regen_golden.py:_synthetic_inputs(CONFIG) with numpy.random.seed(42). No HuggingFace weights are loaded.

+ +

Why synthetic, not real weights?

+
    +
  • Reproducibility. Anyone with a fresh checkout can regenerate the goldens and run the ablation without needing the ~5 GB Llama-3.2-1B weight download or HuggingFace credentials.
  • +
  • Determinism. A fixed seed makes the same kernel produce the same outputs across runs and machines, so bit-exact validation is meaningful.
  • +
  • Doesn't matter for dispatch ablation. The 4 cells differ in how data flows between kernels, not what the data means semantically. A weight tensor of N(0,1) values exercises the same DMA paths and the same MMA instructions as a real Llama weight tensor.
  • +
+ +

Limitation: the dispatch-overhead conclusions transfer to real weights, but a numerical-precision study (e.g., "does our quantization match HuggingFace's outputs to within X tolerance") would need real weights and is out of scope for this ablation.

+ + +

2.4 Timing protocol + environment

+ +

From run_ablation.py:_time_cell:

+ +
def _time_cell(run_fn, n_trials, *args):
+    """Run n_trials, drop trial 1 (warmup), median + (min, max) of remaining."""
+    times = []
+    for _ in range(n_trials):
+        out = run_fn(*args)
+        times.append(out["_wall_s"])
+    keep = times[1:]                  # drop warmup
+    keep_sorted = sorted(keep)
+    return {
+        "median_s": keep_sorted[len(keep_sorted) // 2],
+        "min_s": min(keep), "max_s": max(keep),
+        "all_trials_s": times,
+    }
+ + + + + + +
ChoiceWhy
5 trials per cellEnough samples to see variance; small enough to keep total run time ≤ 10 minutes
Drop trial 1 (warmup)First call after a fresh KernelCache load incurs one-time JIT-style XRT context warmup, instruction-cache fill, and BO-allocation costs. Trial 2+ are at steady state — what we actually want to measure
Report median + (min, max)Median is robust to one-off outliers (a kernel scheduling hiccup, a host CPU preemption). Reporting min/max exposes the variance so the reader can judge whether the median is meaningful
+ +

In practice the within-cell range is small (Decode Cell D: 90.57-90.69 ms; Prefill Cell A: 1.751-1.755 s — under 0.5% of mean). The cell-to-cell deltas are much larger than within-cell noise, so 4 timed trials give statistically meaningful conclusions.

+ +

Environment isolation. The host is multi-tenant — concurrent NPU jobs would corrupt timing. Every run acquires flock -x -w 1800 /tmp/mlir-air-npu.lock before touching the NPU; other NPU jobs block on the same lock for the duration.

+ + +

Part 3 — Decode (full per-token end-to-end)

+ +

3.1 Scope + design decisions

+ +

The 4-cell ladder applied to the production decode path: 16 layers × (rms_gemv_rope NPU + decode_attention_cpu CPU + o_gemv_ffn NPU) + final RMSNorm + lm_head_gemv NPU + argmax. CPU attention and LM head are held INVARIANT across cells (only the NPU dispatch changes between A/B/C/D). Goal: reproduce profile.md's per-token decode latency with Cell D and decompose the optimization contributions.

+ +

Design decisions made before implementation

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + +
QuestionDecisionWhy
Tokens per timed trial?1 decode token per trial × 5 trials at fixed current_pos = 7Multi-token timing has position-dependent decode_attention_cpu work (CPU attention reads [0:current_pos+1] of the cache, growing each token). Single-token-at-fixed-position keeps the CPU work CONSTANT across trials and across cells, isolating dispatch overhead. Within-cell variance < 0.5%.
LM head treatment?Held INVARIANT (production-merged 8-partition GEMV in every cell)Mirrors prefill's treatment of FlashAttention. Reported as a separate "fixed cost per token" line (~13.6 ms/token) so it's visible but doesn't bias cell-to-cell deltas.
KV cache state?Deterministic synthetic pre-fill of 7 tokens (seed=42), reset between trialsEach trial starts from byte-identical cache state. tests/test_kv_cache_state.py verifies determinism.
decode_attention_cpu attribution?Counted in per-token total + reported separately as "CPU floor"It's CPU-side and invariant. Hiding it would mislead readers; reporting it separately keeps cell-to-cell deltas clean.
Production CPU attention vs experimental NPU FA?Production CPU-attention path onlyThat's what profile.md's decode latency reflects. NPU FA decode path exists for Llama-3B (head_dim=128) but isn't used at head_dim=64.
+ +

Validation

+ +

Same bit-exact gate as Part 4: every cell's per-kernel-group output must match the committed Cell D goldens (golden_rms_gemv_rope_decode.npz, golden_o_gemv_ffn_decode.npz) on the seed=42 synthetic input. All 4 cells passed validation in the production run.

+ + +

3.2 Results

+ +

Per-token total wall time

+ + + + + + + + +
CellMedianRangeΔ vs prevSpeedup
A Naive no-merge256.69 ms[256.20, 257.89](baseline)
B + per-layer weight BOs (#2)116.92 ms[114.71, 117.73]+139.77 ms2.20×
C + shared intermediate BOs (#3)113.77 ms[112.95, 114.30]+3.15 ms1.03×
D + multi-launch merging (#1)90.65 ms[90.57, 90.69]+23.12 ms1.26×
A → D total−166.04 ms2.83×
+ +

5 trials per cell, drop trial 1 (warmup), median + (min, max) over remaining 4. All 4 cells produced bit-identical outputs vs. committed Cell D goldens.

+ +

Per-kernel-group medians (single call)

+ +

Same format as the per-token total wall table above — Δ vs prev cell + speedup. Two stacked tables, one per kernel-group. Cells A/B/C dispatch each sub-launch as a separate xrt.run() (6 for rms_gemv_rope, 8 for o_gemv_ffn); Cell D collapses each kernel-group into one merged xrt.run().

+ +

rms_gemv_rope (6 sub-launches → 1 merged in D)

+ + + + + + + + +
CellMedianΔ vs prevSpeedup vs prev
A Naive (6 separate)2.40 ms(baseline)
B + per-layer weight BOs (#2)1.48 ms−0.92 ms1.62×
C + shared intermediate BOs (#3)1.44 ms−0.04 ms1.03×
D + multi-launch merging (#1, 1 merged)0.87 ms−0.57 ms1.66×
A → D total−1.53 ms2.76×
+ +

o_gemv_ffn (8 sub-launches → 1 merged in D)

+ + + + + + + + +
CellMedianΔ vs prevSpeedup vs prev
A Naive (8 separate)12.45 ms(baseline)
B + per-layer weight BOs (#2)4.62 ms−7.83 ms2.69×
C + shared intermediate BOs (#3)4.51 ms−0.11 ms1.02×
D + multi-launch merging (#1, 1 merged)3.67 ms−0.84 ms1.23×
A → D total−8.78 ms3.39×
+ +

Component breakdown (Cell D)

+ + + + + + + + +
ComponentWall timeNote
16 × rms_gemv_rope.elf~14 ms0.87 ms × 16
16 × o_gemv_ffn.elf~59 ms3.67 ms × 16
16 × decode_attention_cpu (CPU floor, invariant)3.68 msSame in every cell
1 × lm_head_gemv.elf (production-merged, invariant)13.62 ms8-partition GEMV in 1 xrt.run, held INVARIANT
Total per-token wall90.65 msSum (within rounding)
+ +

Findings

+ +
    +
  • #2 (per-layer weight BOs) DOMINATES — 2.20× alone. Per-layer weight BOs save ~140 ms per token of avoided host→device weight upload. Decode is dispatch/weight-upload bound (per-call NPU work is small relative to weight DMA cost), so eliminating that DMA is the single biggest lever.
  • +
  • #3 (shared intermediate BOs) contributes ~zero — 1.03×. Decode intermediates are KB-scale (4-8 KB each); at that size the host round-trip cost is dominated by sync + dispatch fixed overhead, not byte transfer. BO aliasing only removes byte transfer, so its benefit is invisible. (Compare prefill's 1.31× in Part 4 — there, MB-scale intermediates make the same optimization the dominant gain.)
  • +
  • #1 (multi-launch merging) gives 1.26×. Smaller as a fraction of the total because per-token wall includes ~17 ms of invariant fixed cost (LM head + CPU attention) that aren't ablation-affected.
  • +
  • Total A→D = 2.83×, dominated by #2 (per-layer weight BOs). Both NPU kernel-groups (rms_gemv_rope and o_gemv_ffn) benefit from the same optimization for the same reason.
  • +
  • All 4 cells produce bit-identical outputs for both rms_gemv_rope and o_gemv_ffn against committed goldens.
  • +
+ + + +

Part 4 — Prefill (full 16-layer)

+ +

4.1 Scope

+ +

The 4-cell ladder applied to the production prefill path: 16 layers × (rms_gemms_rope NPU + FlashAttention NPU + o_ffn NPU). FlashAttention is held INVARIANT across cells (it's un-mergeable into the surrounding kernel-groups, see B5). Goal: reproduce profile.md's 1.27 s prefill headline with Cell D and decompose the optimization contributions.

+ +

Per-call shapes are very different from decode: prefill operates at seq=2048, so a Q-GEMM output is [2048, 2048] = 8 MB bf16, and a Gate/Up GEMM output is [2048, 8192] = 32 MB bf16. Per-call NPU compute is in tens of milliseconds, not hundreds of microseconds — the bottleneck physics is fundamentally different from decode.

+ +

Dispatch counts per prefill pass: Cells A/B/C launch every sub-kernel as a separate xrt.run() — 6+1+8 = 15 dispatches per layer × 16 layers = 240 xrt.run() per pass. Cell D collapses each kernel-group into one merged ELF: 1+1+1 = 3 dispatches per layer × 16 = 48 per pass. The 5× dispatch reduction is what optimization #1 buys; it's measured by the C → D delta.

+ +

The 16-layer wrapper that threads o_ffn output → next layer's rms_gemms_rope input is in cells/multi_layer.py. The same wrapper is used by all 4 cells; only the per-kernel-group dispatch function changes.

+ + +

4.2 Results

+ +

16-layer total wall time — direct comparison to profile.md's 1.27 s

+ + + + + + + + +
CellMedian (s)RangeΔ vs prevSpeedupvs profile.md 1.27 s
A Naive1.754[1.751, 1.755](baseline)1.38× slower
B + per-layer weight BOs (#2)1.589[1.584, 1.594]+0.165 s1.10×1.25× slower
C + shared intermediate BOs (#3)1.212[1.212, 1.222]+0.377 s1.31×0.95× faster
D + multi-launch merging (#1)1.125[1.124, 1.127]+0.087 s1.08×0.89× faster
A → D total−0.629 s1.56×All three combined
+ +

Cell D = 1.125 s — close to profile.md's 1.27 s. The ~150 ms gap is host-side work outside the dispatch loop (embedding lookup, final RMSNorm, LM Head GEMV, KV-cache extraction transposes).

+ +

Per-kernel-group medians (single call)

+ +

Same format as the 16-layer total wall table above. Two stacked tables, one per kernel-group. Cells A/B/C dispatch each sub-launch as a separate xrt.run() (6 for rms_gemms_rope, 8 for o_ffn); Cell D collapses each kernel-group into one merged xrt.run().

+ +

rms_gemms_rope (6 sub-launches → 1 merged in D)

+ + + + + + + + +
CellMedianΔ vs prevSpeedup vs prev
A Naive (6 separate)14.99 ms(baseline)
B + per-layer weight BOs (#2)12.52 ms−2.47 ms1.20×
C + shared intermediate BOs (#3)9.77 ms−2.75 ms1.28×
D + multi-launch merging (#1, 1 merged)7.43 ms−2.34 ms1.31×
A → D total−7.56 ms2.02×
+ +

o_ffn (8 sub-launches → 1 merged in D)

+ + + + + + + + +
CellMedianΔ vs prevSpeedup vs prev
A Naive (8 separate)75.05 ms(baseline)
B + per-layer weight BOs (#2)64.67 ms−10.38 ms1.16×
C + shared intermediate BOs (#3)45.01 ms−19.66 ms1.44×
D + multi-launch merging (#1, 1 merged)40.99 ms−4.02 ms1.10×
A → D total−34.06 ms1.83×
+ +

Sanity check vs the 16-layer table: Cell D — 16 × (7.43 + 40.99) = 775 ms for the two kernel-groups + ~22 ms × 16 = ~350 ms for FA = ~1.12 s, matches the 1.125 s 16-layer wall. Cell A — 16 × (14.99 + 75.05) = 1441 ms + 350 ms FA = ~1.79 s, close to the 1.754 s 16-layer wall.

+ +

Findings

+ +
    +
  • #3 (shared intermediate BOs) DOMINATES — 1.31× alone. OPPOSITE of decode (where #3 ≈ 1.0×). Why: prefill intermediates are LARGE (8 MB GEMM output, 32 MB FFN intermediate), so the bandwidth saved by aliasing BOs across separate xrt.run() calls is substantial.
  • +
  • #2 (per-layer weight BOs) is small — 1.10×. Weights are still ~14 MB each, but per-call NPU compute is now ~10-50 ms (vs. < 1 ms in decode). The fraction of total time that's weight-DMA-bound shrinks dramatically.
  • +
  • Pure multi-launch merging (#1) is small — 1.08×. Same reason: dispatch overhead matters proportionally less when each kernel has tens of ms of NPU work.
  • +
  • Total A → D = 1.56× — smaller than decode's 2.83× because dispatch-related overheads are a smaller share of total wall time at prefill scale.
  • +
  • All 4 cells produce bit-identical outputs against committed goldens for both kernel-groups.
  • +
+ + +

Part 5 — Reading the results, reproducing, and limitations

+ +

5.1 Cross-comparison: decode vs. prefill

+ +

The single most surprising finding is how dramatically the contribution of each optimization SHIFTS between phases:

+ + + + + + + +
OptimizationDecode (Part 3)Prefill (Part 4)Why
#1 — Multi-launch merging1.26×1.08×Decode per-call NPU work small → dispatch overhead matters more; prefill per-call work in tens of ms → dispatch is small fraction
#2 — Per-layer weight BOs2.20×1.10×Decode weights dominate per-call cost (small compute, ~14 MB weights per call); prefill compute dominates (large compute amortizes the weight upload)
#3 — Shared intermediate BOs1.03×1.31×Decode intermediates are KB-scale → host-hop dominated by sync/dispatch overhead, byte-transfer saving invisible. Prefill MB-scale intermediates → byte-transfer saving real
A → D total2.83×1.56×Decode is dispatch-bound; prefill is more compute-bound
+ +
+ The key insight: the same 4-cell ladder yields a near-3× speedup for decode but only ~1.5× for prefill. The dominant optimization flips between phases — #2 (per-layer weight BOs) for decode, #3 (shared intermediate BOs) for prefill. Targeting the wrong one would yield a fraction of the available speedup. +
+ + +

5.2 Reproducing

+ +

Decode (Part 3)

+ +
cd programming_examples/llama32_1b/ablation/decode
+make clean
+make all                # compile + run all 4 cells, generate report
+ +

Expected: Cell D per-token median ≈ 90 ms; A → D speedup ≈ 2.8×.

+ +

Prefill (Part 4)

+ +
cd programming_examples/llama32_1b/ablation/prefill
+make clean
+make all
+ +

Expected: Cell D 16-layer total wall ≈ 1.1-1.2 s — within ~10% of profile.md's 1.27 s production headline.

+ +

Validation gate (no NPU touch)

+ +
python3 -m pytest tests/ -v
+ +

Useful as a smoke check before queuing on the shared NPU lock.

+ + +

5.3 Limitations + how to extend

+ + + + + + + + + + + + + + + + + + + + + + + + + + + +
LimitationMitigation / how to extend
Synthetic weights only (numpy seed=42)Dispatch ablation is independent of weight values — same DMA paths, same MMAs. To verify, swap in HuggingFace weights and confirm the per-cell deltas are within noise.
FlashAttention held invariant in prefill (Part 4)FA is un-mergeable into the surrounding kernel-groups (compiler pass complexity, see B5), so varying it across cells would be unfair. A follow-up could ablate aliasing FA's input BOs to rms_gemms_rope's output BOs (cross-kernel-group BO sharing).
Prefill harness doesn't include LM head, embedding, KV-cache transposeThat's why prefill Cell D = 1.125 s while profile.md = 1.27 s. The ~150 ms residual is host-side work outside the dispatch loop. (Decode harness DOES include LM head + final RMSNorm + argmax, so its Cell D matches profile.md directly.)
Single decode token at fixed positionBy design (Part 3.1) — keeps CPU attention work constant. Multi-token decode would have position-dependent CPU time that would mask dispatch ablation. To extend, run a per-position sweep separately.
BF16 bit-exactness as the validation gateCatches dispatch-induced computation differences but doesn't validate that the kernels themselves are numerically correct (i.e. agree with HuggingFace transformers bf16). That validation is the production make verify top-k token gate (see VERIFICATION.html); not duplicated here.
5 trials per cellWithin-cell variance is small (≤1% of mean) and inter-cell deltas are large (10s of percent), so 5 trials is enough for the conclusions stated. For finer-grained claims (e.g. "is gap #3 contributing exactly zero or just <1% in decode"), more trials would tighten the confidence interval.
+ + +

5.4 File map

+ + + + + +
PathPurpose
programming_examples/llama32_1b/ablation/decode/Full per-token decode harness (Part 3). Self-contained: specs, standalone builders, cells A-D, per_token_loop, KV cache helpers, goldens, orchestrator, Makefile, README, 8 unit tests. Re-uses Plan 1's parameterized infrastructure where possible.
programming_examples/llama32_1b/ablation/prefill/Full prefill harness (Part 4). Specs, standalone builders, cells A-D, multi_layer wrapper, FA invariant runner, goldens, orchestrator, Makefile, README, 8 unit tests.
+ +

Companion documents

+ + + + + + +
DocWhere it fits
IMPLEMENTATION_GUIDE.htmlSister document — describes the production codebase and the four gaps. This ablation quantifies them.
profile.mdSource of the 1.27 s prefill headline reproduced by Cell D in Part 4.
../ablation/docs/specs/ + plans/Pre-implementation specs and step-by-step plans for each study (prefill 2026-05-07, decode 2026-05-12).
+ +
+ Quick recap. Two ablation studies (decode + prefill), both implemented and measured with the same 4-cell ladder (A naive → B + weight BOs → C + intermediate BOs → D production-merged). Each cell verified bit-exact against a committed golden; timed as median over 4 trials (drop warmup); NPU exclusive-locked. +
    +
  • Decode (Part 3, full per-token end-to-end): A→D = 2.83×, Cell D = 90.65 ms/token, dominated by per-layer weight BOs (#2) at 2.20× alone.
  • +
  • Prefill (Part 4, 16-layer end-to-end): A→D = 1.56×, Cell D = 1.13 sprofile.md's 1.27 s within 10%, dominated by shared intermediate BOs (#3) at 1.31× alone.
  • +
+ The most actionable finding: which optimization dominates flips between decode and prefill. For decode: target #2 (per-layer weight BOs). For prefill: target #3 (shared intermediate BOs). Targeting the wrong one would yield a fraction of the available speedup. +
+ + + diff --git a/programming_examples/llama32_1b/docs/IMPLEMENTATION_GUIDE.html b/programming_examples/llama32_1b/docs/IMPLEMENTATION_GUIDE.html new file mode 100644 index 000000000..0a80a740d --- /dev/null +++ b/programming_examples/llama32_1b/docs/IMPLEMENTATION_GUIDE.html @@ -0,0 +1,3342 @@ + + + + +Llama-3.2-1B on AMD NPU2 — Implementation Guide + + + + +

Llama-3.2-1B on AMD NPU2 — Implementation Guide

+

A model-first walkthrough: understand what Llama-3.2-1B inference IS, then how this codebase runs it on AMD NPU2 hardware.

+ + + + + + +
+How to read this guide: Read Part A first if you're unsure what Llama-3.2-1B inference does at the math level. Part A has no NPU code — just the model itself and its data flow. Then Part B shows how this codebase realizes Part A on AMD NPU2 hardware. Part C is a one-page pointer to the verification subsystem (full design in VERIFICATION.html). Part D lists known optimizations not yet implemented. Part E is reference material to come back to as needed. +
+ + +

Part A — The Model (no NPU yet)

+ +

A1. Llama-3.2-1B at a glance

+ +
+

Llama-3.2-1B is a 1.24-billion-parameter decoder-only transformer language model from Meta, released in 2024. Given a sequence of input tokens, it produces a probability distribution over the vocabulary for the next token. Repeated autoregressively, this generates text.

+
+ +

Hyperparameters (defined in LlamaConfig at llama32_1b_weights.py:36)

+ + + + + + + + + + + + + +
ParameterValueWhat it means
n_layers16Number of stacked transformer blocks
emb_dim (d_model)2048Hidden dimension everything flows through
n_heads32Number of Q heads in attention
n_kv_heads8Number of K/V heads (GQA: 4 Q heads share each KV head)
head_dim64Per-head dimension. Note: 32 × 64 = 2048 = emb_dim
hidden_dim8192FFN intermediate width (gate/up/down projections expand to this)
vocab_size128256Tokenizer vocabulary size; LM Head outputs this many logits
seq_len2048Fixed prefill length in this implementation (not a model property)
weight dtypebfloat1616-bit brain-float for all weights and activations
RoPE base500000Rotary Position Embedding base frequency
+ +

Total parameter accounting (~1.24 B)

+ + + + + + + + + + + + + + + + + +
ComponentPer layer× 16 layersPer-tensor shape
Attention norm weight2,04832,768(2048,)
Q projection4.19 M67.1 M(2048, 2048)
K projection1.05 M16.8 M(2048, 512)
V projection1.05 M16.8 M(2048, 512)
O projection4.19 M67.1 M(2048, 2048)
FFN norm weight2,04832,768(2048,)
Gate projection16.8 M268 M(2048, 8192)
Up projection16.8 M268 M(2048, 8192)
Down projection16.8 M268 M(8192, 2048)
Per-layer subtotal61.0 M976 M~ 122 MB bf16
Embedding table263 M(128256, 2048)
Final norm2,048(2048,)
LM Head (vocab projection)263 M(128256, 2048)
Grand total≈ 1.50 B~ 3.0 GB bf16
+ +

Note: Llama-3.2-1B uses untied embeddings (LM Head is a separate parameter from the embedding table). That's why total is ~1.50 B not ~1.24 B if you sum just the published parameter count. The embedding table is loaded but the embedding lookup is a host-side numpy index, not an NPU kernel.

+ + +

A2. The transformer block — math and shapes

+ +

Llama-3.2-1B is just 16 of these blocks stacked, sandwiched between a token embedding lookup at the start and a final RMSNorm + LM Head at the end. (See A3 for the full top-level pipeline.)

+ +

One transformer block is a function block(x) → output where both x and output have the same shape [B, S, H]. The block has two sub-blocks (attention and FFN), each with a residual connection. We diagram them separately to keep each readable.

+ +

Symbol convention (used in every shape annotation below)

+ + + + + + + + + + + + + +
SymbolMeaningLlama-3.2-1B value
Bbatch size1 (this implementation is single-stream)
Ssequence length2048 (prefill) or 1 (decode)
Hhidden dim (d_model)2048
Lnumber of decoder layers16
N_hquery head count32
N_kvKV head count (GQA)8
GGQA group size = N_h / N_kv4
d_hper-head dim = H / N_h64
D_ffFFN intermediate dim8192
Vvocab size128256
+ +

Note: H = N_h · d_h = 32 · 64 = 2048, and the K/V projection output is N_kv · d_h = 8 · 64 = 512 (smaller than H because of GQA).

+ +
+
Linear / matmul / weight-bearing — Q/K/V/O proj, gate/up/down, embedding, LM head
+
Norm / activation / attention compute — RMSNorm, RoPE, SiLU, scaled dot-product attention
+
Data / structural — input/output tensors, residual adds
+
+ + +

A2.1 — Attention sub-block

+ +

From the block's input x, the attention sub-block produces an updated hidden state with cross-position information mixed in (causally — only earlier positions affect later ones). Three weighted projections (Q, K, V) plus RoPE, attention compute, and an output projection. The output is added to a saved copy of x (residual).

+ + + + + + + + + + + + Input x + [B, S, H] + + + + + save x for residual + + + + [B, S, H] + + + + + RMSNorm + γ: [H], row-wise on H + + + + + + + + [B, S, H] (broadcast to 3) + + + + + Q proj + W_q: [H, N_h·d_h] + + + + + K proj + W_k: [H, N_kv·d_h] + + + + + V proj + W_v: [H, N_kv·d_h] + + + + + + + [B, S, N_h·d_h] + [B, S, N_kv·d_h] + + + + + RoPE on Q + cos/sin LUT [S, d_h] + + + + + RoPE on K + cos/sin LUT [S, d_h] + + + + + V passthrough + no rotation + + + + + + + q_roped + k_roped + v + + + + + Scaled dot-product attention (causal, GQA) + S = softmax(Q · K^T / √d_h, causal_mask) · V + FlashAttention fuses softmax with the matmuls; GQA = each Q head shares a KV head + no learnable weights + + + + + [B, S, N_h·d_h] = [B, S, H] + + + + + Output projection + W_o: [N_h·d_h, H] + + + + + [B, S, H] + + + + + Residual add: out = x + proj + [B, S, H] + + + +

Per-kernel explanations (attention sub-block)

+ +
+RMSNorm (input normalization) +
    +
  • Shape: [B, S, H][B, S, H], weight γ: [H]
  • +
  • Op: y = x · rsqrt(mean(x², dim=-1) + ε) · γ
  • +
  • Application: row-wise on H. Each (b, s) position is normalized independently along the hidden dim. No mean subtraction (unlike LayerNorm), no bias. The mean is over 2048 elements per row.
  • +
+
+ +
+Q projection +
    +
  • Shape: [B, S, H][B, S, N_h·d_h] (= [B, S, H] since H = N_h · d_h), weight W_q: [H, N_h·d_h]
  • +
  • Op: Y = X @ W_q (no bias)
  • +
  • Application: per-token GEMM, contraction dim is H. Each (b, s) row maps independently; B · S can be flattened into the M dim for batching. In our impl: prefill is a GEMM at M=2048; decode is a GEMV at M=1.
  • +
+
+ +
+K projection +
    +
  • Shape: [B, S, H][B, S, N_kv·d_h], weight W_k: [H, N_kv·d_h]
  • +
  • Op: Y = X @ W_k (no bias)
  • +
  • Application: per-token GEMM with contraction dim H. The output dim is 4× smaller than Q because of GQA (only 8 KV heads vs 32 Q heads).
  • +
+
+ +
+V projection +
    +
  • Shape: [B, S, H][B, S, N_kv·d_h], weight W_v: [H, N_kv·d_h]
  • +
  • Op: Y = X @ W_v
  • +
  • Application: identical pattern to K projection. (Could be fused with K — but typically isn't because they're each large enough on their own.)
  • +
+
+ +
+RoPE on Q (Rotary Position Embedding) +
    +
  • Shape: [B, S, N_h, d_h][B, S, N_h, d_h] (unchanged), reads cos/sin LUT of shape [S, d_h]
  • +
  • Op: rotate each (b, s, h) head's d_h-vector by the angle determined by position s. Q_roped[b,s,h,i] = Q[b,s,h,i]·cos[s,i] − Q[b,s,h,i+d_h/2]·sin[s,i] (half-split convention)
  • +
  • Application: per-(position, head) elementwise rotation. The rotation angle is a deterministic function of position alone. The LUT is constant across calls (precomputed by generate_rope_lut). Pure data movement + multiplies; no reductions.
  • +
+
+ +
+RoPE on K +
    +
  • Shape: [B, S, N_kv, d_h][B, S, N_kv, d_h] (unchanged)
  • +
  • Op: identical to RoPE on Q but for K (smaller because only N_kv heads).
  • +
  • Application: per-(position, head) rotation. Same LUT shared with Q.
  • +
+
+ +
+V passthrough +
    +
  • Shape: [B, S, N_kv, d_h] unchanged
  • +
  • Op: none. V does not get RoPE-rotated (only Q and K do).
  • +
  • Application: conceptual node — V is just held until attention compute consumes it. No kernel.
  • +
+
+ +
+Scaled dot-product attention (causal, GQA) +
    +
  • Shape: q_roped: [B, S, N_h, d_h], k_roped: [B, S, N_kv, d_h], v: [B, S, N_kv, d_h]out: [B, S, N_h, d_h]
  • +
  • Op (5 sub-steps): +
      +
    1. Transpose K: for each head pair, K^T swaps the seq and d_h dims.
    2. +
    3. QK^T: scores[b,h,s,t] = Q[b,s,h,:] · K[b,t,h//G,:] / √d_h — note the GQA index h//G shares one KV head across G query heads.
    4. +
    5. Causal mask: set scores[b,h,s,t] = −∞ for t > s so query position s only attends to positions 0..s.
    6. +
    7. Softmax: P[b,h,s,t] = softmax(scores[b,h,s,:]) — normalized over the LAST dim (key positions). Row-wise per query.
    8. +
    9. Weighted sum of V: out[b,s,h,:] = Σ_t P[b,h,s,t] · V[b,t,h//G,:]
    10. +
    +
  • +
  • Application: quadratic in S (attention matrix is S × S). FlashAttention fuses all 5 sub-steps into a tiled kernel that never materializes the full S × S matrix in memory. No learnable weights. Memory-bound for large S, compute-bound for small S.
  • +
+
+ +
+Output projection +
    +
  • Shape: [B, S, H][B, S, H], weight W_o: [H, H]
  • +
  • Op: proj = attn_out @ W_o (no bias)
  • +
  • Application: per-token GEMM. Contraction over the head-flattened dim H = N_h · d_h.
  • +
+
+ +
+Residual add +
    +
  • Shape: x: [B, S, H] + proj: [B, S, H][B, S, H]
  • +
  • Op: res1 = x + proj
  • +
  • Application: pure elementwise. Adds the saved input x to the projection output. No reduction, no broadcast (both inputs same shape). Output is the input to the FFN sub-block.
  • +
+
+ + +

A2.2 — FFN sub-block (SwiGLU)

+ +

Takes the attention sub-block's output (call it res1) and applies a 3-projection feed-forward network with SwiGLU activation. Like the attention sub-block, the result is added to a saved copy of the input.

+ + + + + + + + + + + + Input res1 + [B, S, H] + + + + + save res1 for residual + + + + [B, S, H] + + + + + RMSNorm + γ: [H], row-wise on H + + + + + + + [B, S, H] (broadcast to 2) + + + + + Gate projection + W_gate: [H, D_ff] + + + + + Up projection + W_up: [H, D_ff] + + + + + + gate: [B, S, D_ff] + up: [B, S, D_ff] + + + + + SiLU(gate) + x · σ(x), elementwise + + + + + up (unchanged) + + + + + + + + + + Elementwise mul: SiLU(gate) ⊙ up + [B, S, D_ff], no reduction + + + + + + swiglu: [B, S, D_ff] + + + + + Down projection + W_down: [D_ff, H] + + + + + down: [B, S, H] + + + + + Residual add: out = res1 + down + [B, S, H] — block output + + + +

Per-kernel explanations (FFN sub-block)

+ +
+RMSNorm (FFN) +
    +
  • Shape: [B, S, H][B, S, H], weight γ: [H]
  • +
  • Op: same formula as the attention RMSNorm; uses a different learned γ (called ffn_norm).
  • +
  • Application: row-wise on H.
  • +
+
+ +
+Gate projection +
    +
  • Shape: [B, S, H][B, S, D_ff], weight W_gate: [H, D_ff]
  • +
  • Op: gate = X @ W_gate
  • +
  • Application: per-token GEMM. Expands hidden dim by 4× (2048 → 8192). One of the two compute-heavy GEMMs in the block.
  • +
+
+ +
+Up projection +
    +
  • Shape: [B, S, H][B, S, D_ff], weight W_up: [H, D_ff]
  • +
  • Op: up = X @ W_up
  • +
  • Application: identical pattern to Gate projection. Could be fused with Gate into one [H, 2·D_ff] GEMM (some implementations do this); ours keeps them separate.
  • +
+
+ +
+SiLU(gate) +
    +
  • Shape: [B, S, D_ff][B, S, D_ff] (unchanged)
  • +
  • Op: SiLU(x) = x · σ(x) = x / (1 + e^{−x})
  • +
  • Application: pure elementwise. No cross-axis dependency; each scalar is independent. Often fused with the elementwise multiply that follows.
  • +
+
+ +
+Elementwise multiply: SiLU(gate) ⊙ up +
    +
  • Shape: [B, S, D_ff] × [B, S, D_ff][B, S, D_ff]
  • +
  • Op: swiglu[i] = SiLU(gate[i]) · up[i] — Hadamard product.
  • +
  • Application: elementwise. In our codebase, SiLU and this multiply are fused into one C++ kernel (silu_and_mul.cc), saving one full pass over the 8192-wide tensor.
  • +
+
+ +
+Down projection +
    +
  • Shape: [B, S, D_ff][B, S, H], weight W_down: [D_ff, H]
  • +
  • Op: down = swiglu @ W_down
  • +
  • Application: per-token GEMM. Contracts over D_ff (8192) — this is the largest contraction dim in the model.
  • +
+
+ +
+Residual add (FFN) +
    +
  • Shape: res1: [B, S, H] + down: [B, S, H][B, S, H]
  • +
  • Op: out = res1 + down
  • +
  • Application: pure elementwise. Output is the block output → next layer's input.
  • +
+
+ + +

A2.3 — Block-level annotations

+ +
+
Compute-heavy ops (FLOPs ranking, prefill at S=2048)
+
+The three FFN GEMMs dominate FLOPs because D_ff is 4× larger than H. Per-block prefill FLOPs: +
    +
  • Gate proj: 2 · S · H · D_ff ≈ 2 · 2048 · 2048 · 8192 = 69 GFLOP
  • +
  • Up proj: same as gate ≈ 69 GFLOP
  • +
  • Down proj: 2 · S · D_ff · H ≈ 69 GFLOP
  • +
  • Q proj: 2 · S · H · H ≈ 17 GFLOP
  • +
  • K proj, V proj: each ≈ 4 GFLOP (smaller because of GQA)
  • +
  • O proj: 17 GFLOP
  • +
  • Attention compute: 4 · S² · H ≈ 34 GFLOP (dominated by S² scaling — biggest if S grew)
  • +
+The 3 FFN projections together = 207 GFLOP per layer ≈ 60% of per-layer compute. × 16 layers × 1.27 s prefill ≈ 2.6 TFLOP/s achieved on the NPU. +
+ +
Memory-bound ops (bandwidth-limited at small S)
+
+RMSNorm and the elementwise SwiGLU multiply have low arithmetic intensity (~1 FLOP/byte). Attention's softmax + the sub-multiplies inside FlashAttention also become memory-bound when S is small or d_h is small. In decode (S=1), everything except the GEMVs is memory-bound — this is why the per-token decode time is dominated by weight bandwidth, not FLOPs. +
+ +
Fusable kernel boundaries
+
+Common fusions seen in this and other implementations: +
    +
  • SiLU + elementwise multiply → one kernel (silu_and_mul.cc). Saved per-pass over the 8192-wide tensor.
  • +
  • Gate proj + Up proj → one big GEMM with output dim 2·D_ff (some implementations; ours doesn't currently).
  • +
  • FlashAttention fuses transpose + QK^T + mask + softmax + SV into one tiled kernel (this is exactly what makes "FA" different from naive attention).
  • +
  • RMSNorm + next GEMM can be fused with epilogue tricks; our impl does NOT fuse this (norm is its own sub-launch). Trade-off vs the multi-launch ELF approach.
  • +
+See ABLATION_STUDY.html for measurements of how much our specific multi-launch grouping helps. +
+ +
Convention gotchas (where this implementation differs from "vanilla" Llama)
+
+
    +
  • RoPE half-split vs interleaved. HuggingFace Llama (and our impl, via rope_halfsplit.cc) uses the half-split convention: (d[i], d[i + d_h/2]) are paired for rotation. llama.cpp and the original RoPE paper use interleaved (d[2i], d[2i+1]). The two produce DIFFERENT outputs for the same input — they are not interchangeable. Our LUT layout is [cos_0..cos_{d_h/2-1}, sin_0..sin_{d_h/2-1}] (concatenated, not interleaved), matching the half-split rotation.
  • +
  • Causal mask is implicit in FlashAttention. Our FA kernel takes causal=True and never materializes a mask matrix; it just skips attending to t > s in the inner loop.
  • +
  • RMSNorm has no bias. Unlike LayerNorm. Just x · rsqrt(mean(x²) + ε) · γ. ε is a small constant (1e-5 typically) for numerical stability.
  • +
  • No dropout at inference. (Only relevant at training.)
  • +
+
+ +
GQA effects on KV cache size
+
+With G = 4 (each KV head shared by 4 Q heads), the KV cache is 4× smaller than it would be without GQA. For Llama-3.2-1B at max_seq=2048: +
KV cache size = 2 · L · N_kv · max_seq · d_h · 2 bytes = 2 · 16 · 8 · 2048 · 64 · 2 = ~32 MB +
Without GQA (N_kv = N_h = 32), this would be ~128 MB. The savings matter much more for larger models / longer sequences. +
+ +
Weight sharing
+
+Llama-3.2-1B uses untied embeddings — the LM head W_lm is a separate parameter from the embedding table W_emb. (Some smaller models tie them to save parameters.) Both are [V, H]; together they account for ~526 M of the model's 1.5 B parameters. +
+
+ + +

A2.4 — Mapping back to our codebase

+ +

The 14 ops above map to the production NPU kernels as follows:

+ + + + + + + + + + + + + + +
Sub-blockModel opsNPU realization
AttentionRMSNorm + Q proj + K proj + V proj + RoPE Q + RoPE Krms_gemms_rope.elf — 6 sub-launches stitched into one ELF
Scaled dot-product attentionflash_attn.elf — 1 launch (separate ELF; un-mergeable)
(boundary)O proj + Residual #1First 2 sub-launches of o_ffn.elf
FFNRMSNorm + Gate proj + Up proj + SiLU·mul + Down proj + Residual #2Remaining 6 sub-launches of o_ffn.elf
+ +

So one transformer block = 3 NPU calls (rms_gemms_rope + flash_attn + o_ffn) wrapping a total of 15 sub-launches (6 + 1 + 8). The grouping is not the natural "attention sub-block / FFN sub-block" boundary — instead, the cut is "before FlashAttention" vs "after FlashAttention", because FA must be its own ELF (compile-time scaling issue documented in docs/explain.md). Why this exact grouping is best — and why all 15 sub-launches don't go into one ELF — is the topic of Part B and the ablation study.

+ +

One transformer block as math (paraphrased)

+ +

Below is one Llama-3.2-1B layer written as plain NumPy — useful as a reference for the math, independent of NPU plumbing. (The actual production NPU pipeline is described in Part B; numerical correctness is gated by make verify against HF transformers bf16 — see VERIFICATION.html.)

+ +
def transformer_block(x, lw, rope_lut, config):
+    # Attention sub-block
+    normed = rms_norm(x, lw.attn_norm)
+    q = normed @ lw.wq
+    k = normed @ lw.wk
+    v = normed @ lw.wv
+    q_roped = apply_rope(q, rope_lut)
+    k_roped = apply_rope(k, rope_lut)
+    attn_out = attention(q_roped, k_roped, v, config)   # GQA, causal mask
+    res1 = x + attn_out @ lw.wo
+
+    # FFN sub-block
+    normed2 = rms_norm(res1, lw.ffn_norm)
+    gate = normed2 @ lw.w_gate
+    up = normed2 @ lw.w_up
+    swiglu_out = silu(gate) * up
+    output = res1 + swiglu_out @ lw.w_down
+    return output
+ + +

A3. Full forward pass — what one inference call does

+ +

Top-level pipeline

+ +

The diagram below shows the whole inference call as 6 stages. The decoder block is collapsed (×L) — its internals are diagrammed in A2.

+ + + + + + + + + + + + Input token IDs + [B, S] + + + + [B, S] (integer indices) + + + + + Token embedding + W_emb: [V, H] + + + + [B, S, H] + + + + + Decoder block × L + attention + FFN (with residuals) + L = 16 layers, each = 14 ops (see A2) + writes K, V to KV cache (see A4) + + + + [B, S, H] + + + + + Final RMSNorm + γ: [H], row-wise on H + + + + [B, S, H] + + + + + LM head + W_lm: [V, H], untied + + + + [B, S, V] logits + + + + + argmax over V + at last real-token row + + + + + + + + next_token_id ∈ [0, V) + + + +

Per-stage explanations (top-level pipeline)

+ +
+Token embedding +
    +
  • Shape: [B, S] integer indices → [B, S, H] bf16, weight W_emb: [V, H]
  • +
  • Op: x[b, s, :] = W_emb[token_ids[b, s], :] (table lookup)
  • +
  • Application: per-token gather. No matmul — just numpy fancy-indexing on the host (cheap; the embedding table is large but each lookup reads only H bf16 values per token). Done on CPU in our impl, not on NPU.
  • +
+
+ +
+Decoder block × L +
    +
  • Shape: [B, S, H][B, S, H] per block, repeated L times
  • +
  • Op: x ← block_i(x, layer_weights[i], rope_lut) for i in 0..L-1
  • +
  • Application: sequential dependency between layers (output of layer i is input to layer i+1). Within each layer, ops are mostly per-token; only attention crosses positions (causally). See A2 for the 14-op breakdown.
  • +
  • Side effect: each layer's K and V (after RoPE) are also written to the KV cache for use in decode. See A4.
  • +
+
+ +
+Final RMSNorm +
    +
  • Shape: [B, S, H][B, S, H], weight γ_final: [H]
  • +
  • Op: same RMSNorm formula as inside the blocks; uses a different learned γ (called final_norm).
  • +
  • Application: row-wise on H. In our impl this is computed on CPU because we only need the result at one row (see A7).
  • +
+
+ +
+LM head +
    +
  • Shape: [B, S, H][B, S, V], weight W_lm: [V, H] (untied — separate from W_emb)
  • +
  • Op: logits = X @ W_lm.T (no bias)
  • +
  • Application: per-token GEMM, contraction over H, output dim is V (128256 — the largest output dim in the model). In our impl: only one row is computed (the row at pred_pos), as a 1×V GEMV partitioned 8 ways. See A7 for why this is sufficient.
  • +
+
+ +
+argmax over V +
    +
  • Shape: [B, S, V][B, S] integer indices
  • +
  • Op: next_token = argmax(logits, dim=-1)
  • +
  • Application: per-row reduction. We only argmax the row at pred_pos to get the next token. CPU operation in our impl (cheap — V=128256 single argmax).
  • +
+
+ +

The two operating modes (model-level)

+ +

The forward pass above works for ANY input length. But there are two common usage patterns:

+ + + + + + + + + + + + + + + + + +
ModeInputWhat we doOutputCost
PrefillThe full prompt: token_ids of length S = prompt_lenOne forward pass with seq=S. Save K, V at every layer for every position into a "KV cache" — we'll need them for decode. Argmax at position S-1 gives the first generated token.1 token + populated KV cache~1.27 s for S=2048
DecodeOne token at a time: x of shape (1, 2048) — embedding of the previous output tokenOne forward pass with seq=1. Use the KV cache in attention — the new K, V for this position get appended. Argmax gives the next token.1 new token + KV cache extended by 1 position~92 ms per token
+ +

To generate N tokens of text from a prompt: 1 prefill call + N decode calls. The KV cache is built once during prefill and grows by one row per decode step.

+ + +

A4. KV cache — what it is, why we need it, how it grows

+ +

The problem

+ +

For a sequence of length T, attention computes:

+ +
Q = X @ Wq    # shape (T, n_heads, head_dim)
+K = X @ Wk    # shape (T, n_kv_heads, head_dim)
+V = X @ Wv    # shape (T, n_kv_heads, head_dim)
+attn = softmax(Q @ K.T / √d) @ V   # causal masked
+ +

During decode, position T+1 only adds one new query Q[T+1]. But that query needs to attend to all previous K[0..T] and V[0..T]. If we threw those away after the prefill and recomputed them, we'd redo O(T) work per decode step.

+ +

The solution: cache K and V

+ +

Once K[i] and V[i] are computed for any position i, they never change again (they only depend on x[i] and weights, not on later tokens). So we store them in a per-layer cache and append a new entry per decode step.

+ +

Memory layout in our codebase

+ +

Allocated in llama32_1b_inference.py:369:

+ +
k_cache = np.zeros(
+    (config.n_layers, n_kv_heads, max_seq, head_dim),
+    dtype=bfloat16,
+)
+v_cache = np.zeros((config.n_layers, n_kv_heads, max_seq, head_dim), dtype=bfloat16)
+ + + + + + + +
DimensionSizeWhy
n_layers16Each layer has its own K, V (different transformations of x)
n_kv_heads8GQA — only 8 distinct heads (vs 32 Q heads)
max_seqprompt_len + n_tokensEnough room for the prompt + every generated token
head_dim64Per-head dimension
+ +

Total memory: 16 × 8 × max_seq × 64 × 2 bytes = 16,384 × max_seq bytes ≈ 32 MB at max_seq=2048. Tiny compared to the 3 GB of weights — KV cache is not a memory concern for Llama-1B.

+ +

Visual: how the K/V cache grows

+ +

Showing one layer's K cache (the V cache has the same structure). Each cell is one position; rows are the 8 KV heads.

+ +

State after prefill (prompt_len = 7 tokens, max_seq = 20 in this toy example):

+ +
+
↓ kv_head_idx (8 rows). → position 0, 1, 2, ... 19
+
+ +
+
Populated by prefill (real prompt position)
+
Allocated but empty (zero)
+
+
+ +

State after 4 decode steps (current_pos = 11):

+ +
+
+ +
+
Prefill positions (0..6)
+
Decode positions (7..10)
+
Future positions (11..19, not yet written)
+
+
+ +

The key code points

+ +

(1) Cache allocation — once per generate() call:

+ +
# llama32_1b_inference.py:369
+k_cache = np.zeros((n_layers, n_kv_heads, max_seq, head_dim), dtype=bfloat16)
+v_cache = np.zeros((n_layers, n_kv_heads, max_seq, head_dim), dtype=bfloat16)
+ +

(2) Prefill writes to the cache — extracts k_roped and v from each layer's intermediates:

+ +
# llama32_1b_inference.py:401 — runs after each layer in the prefill loop
+k_cache[layer_idx, :, :seq_len, :] = (
+    k_roped.astype(bfloat16)
+    .reshape(seq_len, n_kv_heads, head_dim)
+    .transpose(1, 0, 2)        # layout: (n_kv_heads, seq_len, head_dim)
+)
+v_cache[layer_idx, :, :seq_len, :] = (
+    v_raw.astype(bfloat16).reshape(seq_len, n_kv_heads, head_dim).transpose(1, 0, 2)
+)
+ +

(3) Decode appends to the cache and reads from it — inside decode_attention_cpu and run_decode_block:

+ +
# llama32_1b_decode.py — paraphrased
+def run_decode_block(x, lw, cache, config, k_cache_layer, v_cache_layer, current_pos, ...):
+    # 1. Compute new k, v from this token (NPU rms_gemv_rope call)
+    out = cache.load_and_run("rms_gemv_rope", ...)
+    new_k_roped = out[12]   # shape (kv_dim,) = (512,) flat
+    new_v       = out[8]    # shape (kv_dim,)
+
+    # 2. Append to cache at current_pos
+    k_cache_layer[:, current_pos] = new_k_roped.reshape and transpose
+    v_cache_layer[:, current_pos] = new_v.reshape and transpose
+
+    # 3. CPU attention reads positions 0..current_pos
+    attn_out = decode_attention_cpu(q_roped, k_cache_layer, v_cache_layer,
+                                     current_pos, n_heads, n_kv_heads, head_dim)
+
+# Inside decode_attention_cpu:
+seq_len = current_pos + 1
+k_cached = k_cache[:, :seq_len, :]    # only positions 0..current_pos
+v_cached = v_cache[:, :seq_len, :]
+# Then standard QKᵀ V softmax against this slice...
+ +
+ Important sequencing detail: at the start of decode, current_pos = prompt_len (NOT 0). The cache positions 0..prompt_len-1 are populated by the prefill. The first decode step writes the new k, v at position prompt_len and reads positions 0..prompt_len for attention (the new entry plus all the prefill entries). +
+ + +

A5. Padding to fixed seq_len + finding the real prompt

+ +

This implementation uses fixed seq_len=2048 because NPU kernels are compiled for one specific shape — recompiling for every prompt length would be prohibitive. So we always pad shorter prompts up to 2048. Let's trace exactly how that works.

+ +

Step 1 — Tokenization (host, CPU)

+ +

In llama32_1b_inference.py:731:

+ +
def _tokenize_prompt(session, prompt_text):
+    if session.model_variant == "instruct":
+        messages = [{"role": "user", "content": prompt_text}]
+        chat_text = session.tokenizer.apply_chat_template(messages, tokenize=False,
+                                                            add_generation_prompt=True)
+        return session.tokenizer.encode(chat_text)
+    return session.tokenizer.encode(prompt_text)
+ +

For "What is the capital of France?" with the instruct model, this returns ~30 tokens (the chat template adds system/user role markers).

+ +

Step 2 — Padding to seq_len

+ +

In llama32_1b_inference.py:754 (run_once):

+ +
tokens = _tokenize_prompt(session, prompt_text)   # length = real prompt
+prompt_len_actual = len(tokens)                  # save the real length
+if len(tokens) < session.seq_len:
+    tokens = tokens + [session.tokenizer.eos_token_id] * (session.seq_len - len(tokens))
+# Now len(tokens) == 2048 always.
+ +

So if the real prompt is 30 tokens long, tokens becomes [real_0, real_1, ..., real_29, EOS, EOS, ..., EOS] with 2018 EOS tokens of padding.

+ +

Step 3 — Recovering the real prompt length inside prefill

+ +

The prefill function doesn't receive prompt_len_actual directly — it gets only the padded token_ids array. It recovers the real length by counting non-EOS tokens (llama32_1b_inference.py:422):

+ +
prompt_len = len([t for t in token_ids if t != tokenizer.eos_token_id])
+pred_pos = prompt_len - 1     # index of the last real prompt token
+ +
+ Caveat: this assumes the real prompt does NOT contain any EOS tokens. For typical text inputs that's true. The instruct chat template uses <|begin_of_text|>, <|start_header_id|>, etc. — none of which are EOS — so this works in practice. If a prompt legitimately contained EOS, this counting would be wrong. +
+ +

Step 4 — Prefill processes ALL 2048 positions but only reads pred_pos's logits

+ +

The NPU runs the full forward pass over all 2048 positions including the EOS padding. The padding positions produce garbage k, v values. But we only use the logits at pred_pos = prompt_len - 1, which is BEFORE any padding (llama32_1b_inference.py:427):

+ +
# Final RMSNorm + LM Head — only on the last real-token row
+last_hidden = np.asarray(x_bf16, dtype=np.float32)[pred_pos:pred_pos + 1]
+last_normed_bf16 = _rms_norm(last_hidden, weights.final_norm).flatten().astype(bfloat16)
+
+# NPU LM Head GEMV (8 partitions) on the single normalized row
+results = decode_cache.load_and_run("lm_head_gemv", ...)
+logits_row = np.concatenate(results, axis=0)[:vocab_size]
+prefill_token = int(np.argmax(logits_row))
+ +

This is one of the production optimizations: instead of running the LM Head GEMM on all 2048 positions and then taking row pred_pos, we extract just that one row first (CPU RMSNorm in <1 ms) and run a 1×128256 GEMV on the NPU. Saves ~150 ms of pointless compute.

+ +

Step 5 — KV cache for decode uses prompt_len, not seq_len

+ +

After prefill, the KV cache has positions 0..2047 populated, but only positions 0..prompt_len-1 contain MEANINGFUL k/v (the rest are garbage from EOS padding). Decode starts at current_pos = prompt_len (llama32_1b_inference.py:573):

+ +
generated_tokens = [prefill_token]
+current_pos = prompt_len            # skip past the garbage padding positions
+x_decode = weights.embed_table[prefill_token].astype(bfloat16)
+
+for token_idx in range(n_tokens):
+    # Run all 16 transformer blocks in decode mode
+    for layer_idx in range(config.n_layers):
+        x = run_decode_block(x, ..., k_cache[layer_idx], v_cache[layer_idx],
+                              current_pos, ...)
+    # LM Head GEMV → next token
+    # ...
+    current_pos += 1            # cache grows by 1 per token
+ +

Inside decode_attention_cpu, the slicing k_cache[:, :current_pos+1, :] ensures we only attend to real prefill positions + actually-decoded positions. The garbage at indices prompt_len..2047 (left over from prefill processing the EOS padding) is never read — those slots are reused by decode if it generates enough tokens to overwrite them.

+ +

Cost of padding

+ +

For a 30-token prompt padded to 2048, the prefill compute does 2048 / 30 ≈ 68× more work than necessary, because every layer processes 2018 padding positions whose results we throw away. This is a deliberate tradeoff: fixed-shape kernels are vastly easier to compile and faster per-position than dynamic-shape kernels would be on this hardware.

+ +

Decode doesn't suffer from this — each decode call only processes ONE token (seq=1), and that token is the real new one.

+ +

Visual summary of the prompt+padding+decode lifecycle

+ +
+
Token IDs in the seq=2048 input array, then growing into decode positions:
+
+ +
+
Real prompt (positions 0..6, prompt_len=7)
+
EOS padding (E) — prefill processes but we ignore the output
+
Decode-generated tokens (current_pos=7,8,9,10,11)
+
+
In a real run with seq_len=2048, the EOS pad band would be 30 → 2048 positions wide. The decode positions start at index 30 (prompt_len) regardless of where the padding ended.
+
+ +
+ Note: the prefill's output token (at pred_pos = prompt_len - 1 = 6) is the FIRST generated token. It becomes generated_tokens[0]. Then decode generates tokens 1, 2, 3, ... and writes their k/v at cache positions prompt_len, prompt_len+1, .... The cache positions don't move; the cache just grows in-place into the previously-allocated max_seq array. +
+ + +

A6. Does padding affect the math at real positions?

+ +

Short answer: No. The hidden state at pred_pos = prompt_len − 1 is bit-identical to what you'd get if you ran with seq=prompt_len instead of seq=2048. (Same bytes, not just same logits.) This is why padding-with-EOS is a sound workaround, not a numerical approximation.

+ +

The reason: of the 14 ops in a transformer block (Part A2), only attention crosses positions. All other ops are per-position: each output row depends ONLY on its own input row. So the only path by which a padding position could contaminate pred_pos's output is through attention — and attention is causally masked, so pred_pos never sees positions later than itself. EOS padding tokens are by construction at indices ≥ prompt_len = pred_pos + 1, all of which the causal mask blocks.

+ +

Per-op analysis: which ops cross positions?

+ +

Let x[i] denote the hidden state at position i. For each op, the question is: does the output at position pred_pos depend on any x[j] with j ≠ pred_pos?

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
OpMathCross-position?Why / why not
Embedding lookupx[i] = embed_table[token_ids[i]]NoPer-token table lookup. Position i depends only on token_ids[i].
RMSNormx[i] · rsqrt(mean(x[i]²)+ε) · wNoThe mean is over the embedding dimension (2048 elements of one row), NOT over positions. RMSNorm at position i depends only on x[i]. Easy to verify: the norm formula has no sum across positions.
Q/K/V projectionQ[i] = x[i] @ Wq (etc.)NoA matmul (seq, emb) @ (emb, out) is independent matmul per row. Q[i] = x[i] @ Wq.
RoPErotate Q[i] by angle θ(i) from LUTNoRoPE rotates each (position, head) pair by an angle that is a function of position alone. Q_roped[i] depends only on Q[i] and the constant LUT[i].
Attentionout[i] = softmax(Q[i] · Kᵀ / √d, mask) · VYes — but maskedThe ONLY cross-position op. With the causal mask, out[i] attends to positions 0..i ONLY. Position pred_pos attends to 0..pred_pos — strictly before any padding. Padding positions are at indices pred_pos+1..2047, all blocked.
O projectionproj[i] = attn_out[i] @ WoNoPer-row matmul.
Residual addres[i] = x[i] + proj[i]NoElementwise per row.
FFN RMSNormsame as aboveNoPer-row.
Gate / Up GEMMsper-row matmulNoPer-row.
SwiGLUSiLU(gate[i]) * up[i]NoElementwise per row.
Down GEMMper-row matmulNoPer-row.
Residual add #2elementwise per rowNoPer-row.
Final RMSNormper-rowNoPer-row.
LM Headlogits[i] = x[i] @ W_lm.TNoPer-row matmul. (And we only compute row pred_pos — see A7.)
+ +
+ The single-point invariant: attention is the only op that mixes positions, and the causal mask guarantees that the mixing only flows EARLIER → LATER, never the reverse. Since EOS padding is appended at positions LATER than pred_pos, no padding position can leak into pred_pos's output through any pathway. +
+ +

What about the padding positions themselves?

+ +

The padding positions DO produce garbage output. EOS embeddings get RMSNormed, projected, RoPE-rotated, and run through attention (which can attend to real tokens earlier in the sequence — so the garbage is "garbage with prompt context"). But we never USE that garbage:

+ +
    +
  • LM Head logits: only computed at pred_pos (see A7), so padding-position logits don't exist.
  • +
  • KV cache for decode: the cache slots at indices prompt_len..2047 are written with garbage K/V from the padding positions. Decode skips them — it starts at current_pos = prompt_len and only reads cache slices 0..current_pos+1, never touching the garbage region. (Visualized in A4 and A5.)
  • +
  • Layer N+1's x_in at padding positions: this gets passed to the next transformer block, where it again produces garbage. Wasted compute, but causally walled off from pred_pos.
  • +
+ +

Subtle case: do dropout, layer norm running stats, etc. matter?

+ +

No, because:

+
    +
  • Dropout is not used at inference time.
  • +
  • RMSNorm has no running statistics (unlike BatchNorm — RMSNorm is purely per-row at inference; no batch statistics to corrupt).
  • +
  • FlashAttention's softmax normalizes per-row (per-query-position) — the denominator at row pred_pos sums over only positions 0..pred_pos due to the causal mask. Padding positions don't enter the sum.
  • +
+ +

How to verify this claim

+ +

You can prove the bit-identity empirically: run prefill on a 30-token prompt padded to 2048, then run prefill on the same 30 tokens with seq_len=30 (no padding) — assuming you have kernels compiled for seq=30, which production doesn't but the CPU reference does. Compare x_bf16[pred_pos] from both runs. They should be byte-equal.

+ +

This is something you have to script yourself if you ever need to re-prove it (make diagnosis probes the NPU vs HF bf16 per-layer cosine — see VERIFICATION.html — but it does not directly compare seq=30 vs seq=2048 padded).

+ + +

A7. Single-row LM Head GEMV — workaround or general optimization?

+ +

Short answer: general optimization. Always sufficient for autoregressive single-stream generation, regardless of padding. Even a real seq=2048 prompt with no padding would only need the logits at the last position to generate the next token.

+ +

Why this is true

+ +

Autoregressive language generation has a one-step lookahead: given hidden states for positions 0..T−1, the next token's distribution depends only on logits[T−1]. The logits at positions 0..T−2 would tell you "if I had sampled here, what would the next token be?" — but you've already committed to the actual tokens at those positions (they're the prompt). You don't re-sample them.

+ +

So the LM Head's job during inference is always the same: project ONE hidden state row (the last position's) into vocab space, argmax (or sample), produce ONE next token.

+ +

Where multi-row LM Head WOULD be needed

+ + + + + + + + + + + + + + + + + + + + + + + +
Use caseWhy multi-row?Used in this implementation?
Training (computing cross-entropy loss against teacher-forced labels)Loss is summed over all positions; need logits everywhereNo — this is inference-only
Speculative decoding (verify a draft model's K-token speculation)Need logits at K positions to score the speculationNo — single-stream sampling only
Beam search (track top-K candidate sequences)Need full distributions at each step for multiple beamsNo — greedy argmax (1 stream)
Dumping logits for analysis / probingResearcher wants per-position logits for downstream analysisNo
+ +

For the standard autoregressive sampling that this implementation does (greedy or top-k), you only need the last position's logits. This optimization holds whether your prompt fits in 30 tokens or 2048 tokens.

+ +

The math savings

+ + + + + + + + + + + + + +
ApproachComputeWhy
Naive: full-seq LM Head(2048, 2048) @ (2048, 128256) = (2048, 128256) ≈ 1 TFLOPComputes 2047 rows you'll never look at
This implementation: single-row GEMV(1, 2048) @ (2048, 128256) = (1, 128256) ≈ 0.5 GFLOPOnly the row you need; ~2000× less compute
+ +

In wall time, this is the "Saves ~150 ms" optimization mentioned in profile.md. Implemented at llama32_1b_inference.py:425-446: extract the single hidden-state row, do RMSNorm on it (CPU, <1 ms because it's one row of 2048 elements), then call the decode-side lm_head_gemv.elf on that single row. The same ELF is reused for both prefill's last-token projection and per-token decode — they're the same operation (1×128256 GEMV).

+ +

Padding workaround vs production-grade variable-length support

+ +

Now to your bigger question: what's the difference between this implementation's padding-with-EOS and what a real production inference server does?

+ +

Our approach is the simplest possible: compile kernels for one fixed shape (seq=2048), pad shorter prompts with EOS. This is appropriate for a research prototype on novel hardware where building a dynamic-shape compiler is itself a research problem.

+ +

Production inference servers (vLLM, TensorRT-LLM, SGLang, llama.cpp, etc.) use much more sophisticated approaches:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
TechniqueWhat it doesThis implementation?Why production needs it
Dynamic-shape kernelsSame kernel handles any seq length, branching at runtime on shapeNo — fixed seq=2048Avoids waste on short prompts; supports any prompt length up to a limit
Chunked prefillSplit a long prompt (e.g., 32K tokens) into chunks of fixed size (e.g., 512), process sequentially with attention reading the cache for earlier chunksNo — single-shot at seq=2048; longer prompts unsupportedSupports prompts longer than the kernel's max seq length
Continuous batchingPack multiple users' requests into one batch; add new requests / remove finished ones every stepNo — single user, single streamMaximize GPU/NPU utilization with multiple concurrent users
Paged KV cacheKV cache split into fixed-size pages (like virtual memory pages); attention gathers them at runtimeNo — contiguous (n_layers, n_kv_heads, max_seq, head_dim) arrayAvoids fragmentation and overcommit when serving many users with variable sequence lengths
Speculative decodingUse a small draft model to speculate K tokens, verify in one big-model forward passNo — vanilla autoregressive~2-3× decode speedup at the cost of ~10-30% extra compute
Quantization (INT8/INT4)Compress weights to lower precision, dequantize in kernelNo — bf16 throughout~2-4× speedup, ~2-4× memory reduction
Multi-node tensor/pipeline parallelismShard model across multiple devicesNo — single NPURequired for models larger than one device's memory
+ +

What our implementation IS vs IS NOT

+ +
+ What this is: a single-user, single-stream, fixed-seq-length, bf16, single-NPU autoregressive LLM inference reference. Optimized for clean code, hardware bring-up, and meaningful end-to-end performance numbers (1.27 s prefill / 92 ms/token decode at seq=2048). Demonstrates that NPU2 + MLIR-AIR can run a real LLM end-to-end. +
+ +
+ What this isn't: a production inference server. To deploy this in production, you'd want chunked prefill (or at least multiple compiled seq lengths to avoid the padding waste on short prompts), continuous batching (for multi-user serving), paged KV cache (for memory efficiency), and quantization (for further speedup). The padding workaround is appropriate for the research artifact; it would be replaced with proper variable-length support in a productionization pass. +
+ +

The "single-row LM Head" optimization is general; the "padding-to-2048" optimization is specific

+ +

To return to your distinction: these are two completely separate things.

+ + + + + + + + + + + + + +
OptimizationAlways applicable?Why
Single-row LM Head GEMV at the end of prefillYes, always. Production servers do this too.Autoregressive sampling only needs the last row's logits, regardless of how the prompt was processed.
Pad short prompts with EOS to 2048No — specific to fixed-shape kernels. Production usually avoids this.It wastes compute (~68× for a 30-token prompt). Only acceptable when dynamic-shape kernels would be even more expensive (e.g., due to compile time, runtime branching cost, or tooling immaturity).
+ +

So when you read the LM Head GEMV code, don't think "this is a workaround". Think "this is the right thing to do, and it happens to also dodge an extra 2047 wasted rows that the padding would have created if we used the full-seq GEMM here".

+ + +

Part B — How we run it on the NPU

+ +

Part A was the model. Now we look at how this codebase realizes those ops on AMD NPU2. The translation is not 1-to-1: the model has 14 ops per layer; production runs them as 3 NPU kernel calls per layer (rms_gemms_rope = ops 1-6, flash_attn = op 7, o_ffn = ops 8-15). That's the "multi-launch merging" optimization at work.

+ +

B1. End-to-end runtime flow

+ +

Implementation overview — prefill

+ +

One inference's prefill phase: from the input prompt to the first generated token. The diagram shows which steps run on CPU (gray, host-side numpy) vs which run on NPU (purple, stitched ELFs). FA is its own ELF (pink-purple); the per-layer triple (rms_gemms_rope.elf, flash_attn.elf, o_ffn.elf) is grouped inside the "decoder block × 16" container. KV cache extraction happens on the host after each layer.

+ + + + + + + + + + + + Prompt → tokenize + pad + CPU; output [B, S=2048] (EOS-padded) + + + + + + + Token embedding lookup + CPU numpy gather; W_emb: [V, H] + + + x: [B, S, H] = [1, 2048, 2048] bf16 + + + + Decoder block × L = 16 (one iteration shown; loop wraps back) + + + + + rms_gemms_rope.elf — NPU, 1 xrt.run + 6 stitched launches: RMSNorm + Q/K/V GEMM + RoPE Q + RoPE K + + + q_roped [S, H]; k_roped [S, kv_H]; v [S, kv_H] + + + + + flash_attn.elf — NPU, 1 xrt.run (separate ELF) + 1 launch; un-mergeable (see B5) + + + attn_out [S, H] + + + + extract k_roped, v + + + + KV cache write + CPU; k_cache[L,:,:S], v_cache[L,:,:S] + + + + + + o_ffn.elf — NPU, 1 xrt.run + 8 stitched launches: O + Add + RMSNorm + Gate/Up + SwiGLU + Down + Add + + + x_next [S, H] (= next layer's x_in) + + + (loop back to rms_gemms_rope for layer L+1) + + + + x: [B, S, H] after 16 layers + + + + + Final RMSNorm at row pred_pos + CPU; only 1 row (see A7); → [1, H] + + + + + + + lm_head_gemv.elf — NPU, 1 xrt.run + 8 stitched partitions; W_lm: [V, H] sliced + + + logits [1, V] = [1, 128256] + + + + + argmax → next_token_id + CPU; first generated token + + + + + next_token_id ∈ [0, V) + + + +

Read the colors: gray = CPU/host (numpy, embedding lookup, KV cache management, argmax), purple = NPU stitched ELF, pink = NPU FlashAttention (always its own ELF, never stitched — see B3). The dashed purple outline marks the 16-layer loop boundary.

+ +

Implementation overview — decode (per token)

+ +

Decode generates ONE token per pass. Per layer it makes 2 NPU calls + 1 CPU step (because attention runs on CPU during decode — see B9 for why). The KV cache is read+appended on each layer.

+ + + + + + + + + + + + Previous token id + scalar (from prefill or prior decode step) + + + + + + + Token embedding lookup + CPU numpy gather; single row of W_emb + + + x_decode: [H] = [2048] bf16 (single token) + + + + Decoder block × L = 16 (one iteration shown; loop wraps back) + + + + + rms_gemv_rope.elf — NPU, 1 xrt.run + 6 stitched launches (GEMV variants of prefill kernels) + + + q_roped [H]; k_roped [kv_H]; v [kv_H] — single-token + + + + + decode_attention_cpu — CPU + reads k/v_cache[L, :, 0:current_pos]; writes new k/v at current_pos + + + attn_out [H] + + + + read 0..pos, + append at pos + + + + KV cache + [16, kv_h, max_seq, d_h] + + + + + + o_gemv_ffn.elf — NPU, 1 xrt.run + 8 stitched launches (GEMV variants of o_ffn) + + + (loop back to rms_gemv_rope for layer L+1) + + + + x: [H] after 16 layers + + + + Final RMSNorm + CPU; single-row, <1 ms; → [1, H] + + + + + + + lm_head_gemv.elf — NPU, 1 xrt.run + SAME ELF reused from prefill (8 partitions) + + + logits [1, V] + + + + + argmax → next_token_id + CPU; → loop back as input to next decode step + + + + + next_token_id ∈ [0, V) + + + +

NPU calls per pass — concrete count

+ + + + + +
PhaseNPU calls per layerNPU calls totalCPU work per layer
Prefill (1 pass, 16 layers)3 (rms_gemms_rope + flash_attn + o_ffn)48 + 1 (lm_head_gemv) = 49KV cache write (numpy slice assign)
Decode (1 token, 16 layers)2 (rms_gemv_rope + o_gemv_ffn)32 + 1 (lm_head_gemv) = 33decode_attention_cpu (single-query GQA against KV cache)
+ +

NPU2 tile array — context

+ +

NPU2 (AMD Strix, AIE2P architecture) has a 32-tile compute array arranged as 8 columns × 4 rows. Plus 8 mem-tiles (L2) and shim tiles for DMA. Each compute tile is a VLIW vector core with its own L1 SRAM. Different kernels use different subsets of the 32 tiles depending on parallelism strategy:

+ + + + + + + +
Herd shapeTiles usedUsed for (typical)
[8, 4]32 / 32 (full)Prefill GEMMs (Q/K/V/O/Gate/Up/Down). M-dim split 8 ways × N-dim split 4 ways.
[8, 1]8 / 32RMSNorm, RoPE (prefill), SwiGLU, eltwise add, GEMV (decode). Row-parallel across one column of tiles.
[1, 1]1 / 32RoPE (decode) — single tile is enough for the tiny single-token rotation.
Cascade [c_nq, c_ns]variesFlashAttention — uses an internal segment + cascade-stages design (4 stages × per-head segments). Hard to give one number; FA stresses the array more than any other single ELF.
+ +

Each kernel's exact tile usage is listed in B2's per-kernel cards. The choice of herd shape is made by the Python builder (passed as herd_x / herd_m / herd_n kwargs) and locked at compile time — it can't change between calls of the same ELF.

+ +

The 4 phases of llama32_1b_inference.py:main

+ +

From make run to printed output:

+ +
+
+

Phase 1: build_session llama32_1b_inference.py:669

+

One-time setup: create KernelCache instances, compile (or load cached) all ELFs, load model weights from HuggingFace, build the RoPE LUT, call prepare_runtime.

+ +

Phase 2: prepare_runtime llama32_1b_inference.py:129

+

Pre-loads ALL weights for ALL 16 layers into per-layer NPU Buffer Objects (BOs), so subsequent inference calls only need to write activations. This is the single biggest cost-amortization in the pipeline (see B7).

+ +

Phase 3: run_once / generate llama32_1b_inference.py:742, 523

+

Tokenize the prompt → pad to seq_len=2048 (see Part A5) → call run_npu_prefill → enter the decode loop.

+ +

Phase 4: decode/print

+

For instruct models, apply chat template; emit tokens incrementally via the streaming callback in interactive mode.

+
+
+

Make targets Makefile:78-99

+
# One-time compile (~3 min)
+make compile
+
+# Run inference
+make run
+make run PROMPT="..."
+
+# With profiling breakdown
+make profile
+
+# Top-k token-level correctness gate vs HF transformers bf16
+make verify
+
+# Per-layer ffn_out cosine vs HF bf16 (informational)
+make diagnosis
+
+# Interactive REPL
+make chat
+
+
+ + +

B2. The kernel building blocks

+ +

Before discussing optimizations (multi-launch ELF stitching, BO management), let's see what the basic units are. The codebase has 7 unique compute kernels that together implement every model op from Part A. Each kernel is one of two implementation patterns:

+ + + + + + + + + + + + + +
PatternHow it worksUsed for
MLIR-only (codegen)The Python builder constructs an MLIR module that describes the operation in the linalg / scf / air dialects. aircc + aiecc lower it to AIE-tile instructions through standard linalg-vectorize and AIR placement passes. Peano compiles the resulting per-tile LLVM IR. No hand-written C++.RMSNorm, GEMM, eltwise add
MLIR + external C++ kernelThe MLIR module declares func.func private @kernel_name { link_with = "kernel.o" } and calls it from inside an air.herd. The .o is a hand-written C++ kernel compiled separately by Peano (LLVM-AIE). aiecc links the .o into the per-tile ELFs.GEMV, RoPE, SwiGLU, FlashAttention
+ +

External C++ is used when a hand-tuned implementation beats codegen — typically for kernels with non-trivial vectorization patterns, double-buffering, or tile-level fused operations (FA's softmax + MMA fusion is the canonical example).

+ +

The compile pipeline (one ELF, regardless of pattern)

+ +
+
Python
builder
+
+
MLIR
module
+
+
aircc
(AIR passes)
+
+
aiecc
(AIE passes)
+
+
Per-tile
ELFs (Peano)
+
+
.elf
+ .insts.bin
+
+ +

For external-C++ kernels, the .o file is compiled by Peano in advance (see kernel_builder/external_kernels.py) and placed in the build directory before aircc runs; aiecc finds it via the link_with attribute when packaging per-tile ELFs.

+ +

The whole pipeline is invoked by XRTBackend.compile(mlir_module) inside KernelCache.compile_and_cache — see kernel_builder/cache.py:251. (B3 covers stitching multiple kernels into one ELF; this section is just the per-kernel building blocks.)

+ +

The 7 kernels — quick index

+ + + + + + + + + + +
KernelPatternMaps to model op (Part A)Source builderExternal C++ (if any)
RMSNormMLIR-onlyRMSNorm (attn-norm, ffn-norm, final-norm)weighted_rms_norm/weighted_rms_norm.py
GEMMMLIR-onlyQ/K/V/O proj, Gate/Up/Down proj (prefill, S=2048)kernel_builder/gemm_builder.py
GEMVMLIR + C++Q/K/V/O proj, Gate/Up/Down proj (decode, S=1); LM Headmatrix_vector_multiplication/bf16/matvec.pymv.ccmv.o + mv_k8192.o
RoPEMLIR + C++RoPE Q, RoPE Krope_lut/rope_lut.pykernel_builder/rope_halfsplit.ccrope.o
SwiGLUMLIR + C++SiLU(gate) ⊙ up — fusedkernel_builder/ffn_swiglu/silu_and_mul.pykernel_builder/ffn_swiglu/silu_and_mul.ccsilu_and_mul.o
FlashAttentionMLIR + C++Scaled dot-product attention (causal, GQA)flash_attention/kernel_fusion_based/attn_npu2_seqfirst.pyflash_attention/kernel_fusion_based/attn_npu2.ccattn.o
Eltwise AddMLIR-onlyResidual add #1, Residual add #2eltwise_add/eltwise_add.py
+ +

External-C++ .o compilation is centralized in kernel_builder/external_kernels.py, which uses Peano (LLVM-AIE, found via $PEANO_INSTALL_DIR) with --target=aie2p-none-unknown-elf -O2 -std=c++20. Each function (compile_silu_and_mul, compile_rope, etc.) checks if the .o already exists and skips if so.

+ + +

B2.1 — RMSNorm

+ + + + + + + +
Source builderprogramming_examples/weighted_rms_norm/weighted_rms_norm.py
External C++None — pure MLIR/codegen
Maps to model opRMSNorm (Part A2 op #1, #10; final norm in Part A3)
Production usageInside rms_gemms_rope.elf + o_ffn.elf (prefill); rms_gemv_rope.elf + o_gemv_ffn.elf (decode); the final RMSNorm at the end of inference is computed on CPU instead (single row only — see A7)
NPU compute tile usageherd [8, 1] = 8 of 32 tiles. One column of 8 tiles, each tile reducing across one slice of rows. Same shape used in both prefill and decode (the per-row reduction doesn't benefit from row-direction parallelism beyond the column count).
+ +

How it's compiled. The Python builder uses FuncOp.from_py_func + @herd to construct an air.herd that does the per-row reduction (sum-of-squares), then the rsqrt + multiply. There's no external C++ — aircc lowers the linalg/scf/arith ops to AIE-tile vector intrinsics, and Peano then turns the per-tile LLVM IR into AIE2P machine code.

+ +

The op: y[i] = x[i] · rsqrt(mean(x[i]², dim=-1) + ε) · γ per row. γ (the learned scale) is a per-feature [H]-shaped weight broadcast across rows. The implementation tiles the row dim across an herd_x-tile-tall herd; each tile reduces and normalizes its rows.

+ +

Quirk: the builder produces a bare air.herd (not wrapped in air.launch). When stitched into a multi-launch ELF, the stitching code wraps it in air.launch { air.segment { herd } } via _wrap_ir_in_launch from kernel_builder/stitching.py. (See B5 for why this wrapping is needed.)

+ + +

B2.2 — GEMM (matrix-matrix multiply, prefill)

+ + + + + + + + +
Source builderprogramming_examples/llama32_1b/kernel_builder/gemm_builder.py (function _build_gemm_module(m, k, n, ...)) — thin wrapper around the upstream BF16 GEMM
Wrapsprogramming_examples/matrix_multiplication/bf16/run.py (function build_module(m, k, n, tile_m, tile_k_l2, tile_k_l1, tile_n, herd_m, herd_n, np_dtype_in, np_dtype_out, arch, direct_codegen)) — the generic BF16 GEMM module builder shared with the standalone GEMM example
External C++None — codegen via aircc's linalg.matmul lowering
Maps to model opsQ proj, K proj, V proj, O proj, Gate proj, Up proj, Down proj (Part A2 ops #2-#4, #8, #11-#12, #14) — during prefill only, where S=2048 makes a true matrix-matrix GEMM
Production usagerms_gemms_rope.elf contains 3 GEMMs (Q, K, V); o_ffn.elf contains 4 GEMMs (O, Gate, Up, Down)
NPU compute tile usageherd [8, 4] = 32 of 32 tiles. Production sets herd_m=8, herd_n=4 — the herd's M dim (8) parallelizes output-row tiles and the N dim (4) parallelizes output-col tiles. This is the only kernel that uses the full NPU2 compute array. Configured per-GEMM in rms_gemms_rope_multi.py:200-209 and o_ffn_multi.py:182-202.
+ +

Relationship to the upstream programming_examples GEMM. There is NOT a separate Llama-specific GEMM kernel. gemm_builder.py is a 30-line wrapper that:

+
    +
  1. Calls the upstream build_module from programming_examples/matrix_multiplication/bf16/run.py with bfloat16 input AND output, arch="aie2p" (NPU2), and direct_codegen=True. This produces a base MLIR module containing one air.herd wrapping a tiled linalg.matmul.
  2. +
  3. Applies an extra transform IR script (the ~100-line GEMM_TRANSFORM_IR string in gemm_builder.py) on top of that module. The transform script does additional tiling, herd-vectorization, vector-contract → f32 cast lifting, and several rounds of cast-pair hoisting that move arith.extf / arith.truncf ops out of the innermost loops.
  4. +
+ +

Without the transform-IR step, the GEMM compiles but the inner-loop quality is significantly worse (extra bf16↔f32 conversions per MMA iteration). The transform script is what makes the production GEMM competitive with hand-written kernels — but the actual linalg.matmul tiling structure comes from the shared upstream builder, not from the wrapper.

+ +

Tile config (prefill default). The wrapper accepts tile_m, tile_k_l2, tile_k_l1, tile_n, herd_m, herd_n. Production uses different configs per GEMM (smaller L2 tiles for the small Q/K/V/O 2048-emb GEMMs, larger for the wider Gate/Up/Down 8192-D_ff GEMMs). All configs come from multi_launch_builder/rms_gemms_rope_multi.py:200-209 and multi_launch_builder/o_ffn_multi.py:182-202.

+ +

Why no external C++. The aircc + aiecc pipeline can lower a tiled linalg.matmul with the right transform IR to the same AIE MMA intrinsic that a hand-written kernel would use. There's no measurable win from hand-rolling the matmul C++.

+ + +

B2.3 — GEMV (matrix-vector multiply, decode)

+ + + + + + + +
Source builderprogramming_examples/matrix_vector_multiplication/bf16/matvec.py (function build_module(M, K, tile_m, m_input, herd_m, ...))
External C++programming_examples/matrix_vector_multiplication/bf16/mv.cc → compiled to mv.o (and mv_k8192.o, see below)
Maps to model opsQ/K/V/O/Gate/Up/Down projections — during decode (S=1 makes it M=1 GEMV); also the LM Head (which is structurally a 1×V GEMV regardless of phase, see A7)
Production usagerms_gemv_rope.elf contains 3 GEMVs (Q, K, V); o_gemv_ffn.elf contains 4 GEMVs (O, Gate, Up, Down); lm_head_gemv.elf is an 8-partition GEMV stitched 8 times
NPU compute tile usageherd [8, 1] = 8 of 32 tiles. Production sets tile_m=8, m_input=4, herd_m=8 — the herd's 8 tiles parallelize the M output dim. With M=1 (S=1 in decode) the GEMV gets ZERO M-direction parallelism within a single tile — the 8 tiles instead each handle a slice of the output rows of the projection. The Down GEMV (K=8192) uses a renamed mv_k8192.o variant with tile_m=2 but the same 8-tile herd shape.
+ +

How it's compiled. The MLIR builder constructs an air.launch wrapping an air.herd whose body calls the C++ kernel @matvec_vectorized_bf16_bf16 (declared private with link_with = "mv.o"). The C++ in mv.cc implements a hand-vectorized y = W @ x using AIE bf16 MMA intrinsics. Peano compiles this to a .o file via kernel_builder/external_kernels.py:compile_mv:

+ +
def compile_mv(tile_m=8):
+    src = _PROJ_ROOT / "matrix_vector_multiplication" / "bf16" / "mv.cc"
+    _compile_kernel(src, "mv.o", extra_flags=[f"-DDIM_M_OUTPUT={tile_m}"])
+ +

The mv_k8192.o trick. The decode o_gemv_ffn.elf needs TWO GEMV variants in one ELF: K=2048 (for O/Gate/Up/normal slots) and K=8192 (for the Down GEMV). MLIR can't have two private functions with the same name and different signatures — so the same mv.cc source is compiled a SECOND time with renamed entry points via -D macros (see kernel_builder/external_kernels.py:155):

+ +
def compile_mv_k8192():
+    _compile_kernel(src, "mv_k8192.o", extra_flags=[
+        "-DDIM_M_OUTPUT=2",
+        "-Dmatvec_vectorized_bf16_bf16=dg_matvec_vectorized_bf16_bf16",  # renamed
+        "-Dlinalg_fill_bf16=dg_linalg_fill_bf16",
+    ])
+ +

The renamed function appears in the merged ELF as a separate symbol, side-by-side with the K=2048 version.

+ + +

B2.4 — RoPE (Rotary Position Embedding)

+ + + + + + + +
Source builderprogramming_examples/rope_lut/rope_lut.py (decode/per-row); for prefill multi_launch_builder/rms_gemms_rope_multi.py:_build_rope_2d wraps it for 2D inputs
External C++programming_examples/llama32_1b/kernel_builder/rope_halfsplit.cc → compiled to rope.o
Maps to model opRoPE Q, RoPE K (Part A2 ops #5, #6)
Production usagerms_gemms_rope.elf + rms_gemv_rope.elf (one RoPE for Q-side, one for K-side per ELF)
NPU compute tile usagePrefill: herd [8, 1] = 8 of 32 tiles (rope_herd_x=8, herd_y=1 in rms_gemms_rope_multi.py; the 8 tiles split the seq dim S=2048 across rows). Decode: herd [1, 1] = 1 of 32 tiles (rope_herd_x=1 in rms_gemv_rope_multi.py; only one row to rotate, so single-tile is sufficient and avoids DMA fan-out overhead).
+ +

How it's compiled. The MLIR builder constructs an air.herd that DMA-loads one row of (cos, sin) LUT plus one row of input data into L1, then calls @rope (declared with link_with = "rope.o"). The C++ in rope_halfsplit.cc implements the per-position rotation.

+ +

The rope_halfsplit.cc story. Two RoPE conventions exist:

+
    +
  • Half-split (used by HuggingFace Llama and our impl): pair (d[i], d[i + d_h/2]) for rotation. LUT layout: [cos_0, ..., cos_{d_h/2-1}, sin_0, ..., sin_{d_h/2-1}].
  • +
  • Interleaved (used by llama.cpp and the original RoPE paper): pair (d[2i], d[2i+1]). LUT layout: [cos_0, sin_0, cos_1, sin_1, ...].
  • +
+

Mixing the two produces wrong outputs. The upstream aie_kernels/aie2p/rope.cc uses the interleaved convention. Llama-3.2-1B needs half-split, so this codebase has its own rope_halfsplit.cc compiled to the same rope.o filename → drop-in replacement, no MLIR changes needed. See kernel_builder/external_kernels.py:119 (compile_rope):

+ +
def compile_rope():
+    src = Path(__file__).resolve().parent / "rope_halfsplit.cc"   # NOT the upstream rope.cc
+    _compile_kernel(src, "rope.o")
+ +

The LUT (cos/sin table) is precomputed once per session by generate_rope_lut in llama32_1b_weights.py and passed as a kernel input — not compiled into the kernel.

+ + +

B2.5 — SwiGLU (silu_and_mul, fused activation)

+ + + + + + + +
Source builderprogramming_examples/llama32_1b/kernel_builder/ffn_swiglu/silu_and_mul.py
External C++programming_examples/llama32_1b/kernel_builder/ffn_swiglu/silu_and_mul.cc → compiled to silu_and_mul.o
Maps to model opsSiLU(gate) + elementwise multiply (Part A2 ops #13 — fused into one kernel)
Production usageo_ffn.elf + o_gemv_ffn.elf (one fused SwiGLU step between gate/up GEMMs and down GEMM)
NPU compute tile usageherd [8, 1] = 8 of 32 tiles (swiglu_herd_x=8, swiglu_herd_y=1). The 8 tiles split the elementwise work across the row dim. SiLU+multiply is memory-bound at this scale — adding more tiles wouldn't help because L2/L1 DMA bandwidth is already saturated.
+ +

How it's compiled. The MLIR builder constructs an air.herd that takes the gate and up tensors as inputs (each [B, S, D_ff]) and produces one output tensor. The herd body calls @silu_and_mul_bf16 (declared with link_with = "silu_and_mul.o"). The C++ implementation does out[i] = SiLU(gate[i]) · up[i] in a vectorized inner loop using AIE bf16 SiLU + multiply intrinsics — fusing the two ops eliminates one full pass over the 8192-wide tensor (vs. doing SiLU and the multiply as two separate kernels).

+ +

Compile (with extra include for utils header): see kernel_builder/external_kernels.py:106 (compile_silu_and_mul):

+ +
def compile_silu_and_mul():
+    src = _PROJ_ROOT / "llama32_1b" / "kernel_builder" / "ffn_swiglu" / "silu_and_mul.cc"
+    include_dir = _get_aie_include_dir()
+    utils_header = Path(include_dir) / "aie_kernels" / "aie_kernel_utils.h"
+    extra = []
+    if utils_header.exists():
+        extra = ["-include", str(utils_header)]
+    _compile_kernel(src, "silu_and_mul.o", extra_flags=extra)
+ + +

B2.6 — FlashAttention

+ + + + + + + +
Source builderprogramming_examples/flash_attention/kernel_fusion_based/attn_npu2_seqfirst.py (function build_module(lk, lkp, lq, lqp, dk, dv, num_q_tiles, num_cascade_stages, num_heads, num_kv_heads, causal))
External C++programming_examples/flash_attention/kernel_fusion_based/attn_npu2.cc → compiled to attn_npu2.o (also copied to attn.o)
Maps to model opScaled dot-product attention (Part A2 op #7) with causal mask + GQA
Production usageflash_attn.elf — its OWN ELF, never stitched with rms_gemms_rope or o_ffn (un-mergeable, see B5)
NPU compute tile usageCascade design — uses ~16-24 tiles depending on config. Production sets num_q_tiles=4, num_cascade_stages=4, num_heads_per_unroll=2. The kernel uses MULTIPLE air.segments (sized [num_heads_per_unroll, 1]) each containing a herd sizes=[c_nq, c_ns]. Effectively the cascade pipelines Q-tile streaming across stages — different from the single-herd pattern of the other 6 kernels. Decode reuses prefill's flash_attn.elf only for full-prefill recomputation (rare); the per-token decode attention runs on CPU instead.
+ +

How it's compiled. Of all 7 kernels, FlashAttention is by far the most complex. The MLIR builder produces a multi-tile cascade of air.herds that stream Q tiles through K/V tiles using air.channels for inter-tile DMA. The actual softmax + MMA fusion is in C++ (attn_npu2.cc), which exposes ~16 functions for the FA tile primitives (Q tile load, K tile load, dot-product, online softmax update, V multiply-accumulate, rescale, etc.).

+ +

Many compile-time flags. See kernel_builder/external_kernels.py:130 (compile_attn_npu2):

+ +
def compile_attn_npu2(head_dim=64):
+    src = _PROJ_ROOT / "flash_attention" / "kernel_fusion_based" / "attn_npu2.cc"
+    _compile_kernel(src, "attn_npu2.o", extra_flags=[
+        "-DBIT_WIDTH=8",
+        f"-Dlqp={head_dim}",        # Q-per-tile
+        f"-Dlkp={head_dim}",        # K-per-tile
+        f"-Ddk={head_dim}",         # head dim, K side
+        f"-Ddk_full={head_dim}",
+        f"-Ddv={head_dim}",         # head dim, V side
+        f"-Ddv_full={head_dim}",
+        "-DAIE_API_EMULATE_BFLOAT16_MMUL_WITH_BFP16",
+        "-DROUND_CONV_EVEN",
+    ])
+    # Some link_with attrs use "attn.o", so make a copy
+    shutil.copy2("attn_npu2.o", "attn.o")
+ +

Most of these -D flags are head_dim parameters that the C++ uses to size internal tile buffers at compile time. head_dim=64 for Llama-3.2-1B; the same kernel works for Llama-3.2-3B with head_dim=128.

+ +

Why this can't go in a multi-launch ELF. The cascade design uses many air.channels and stresses the air-opt-shim-dma-bds compiler pass quadratically. With 9+ launches (i.e., FA + the rms_gemms_rope launches) in one ELF, this pass takes >10 minutes. So FA stays as its own single-launch ELF and is invoked between rms_gemms_rope and o_ffn from the host (see B5). This is the main reason production has 3 NPU calls per layer instead of 1.

+ + +

B2.7 — Eltwise Add (residual)

+ + + + + + + +
Source builderprogramming_examples/eltwise_add/eltwise_add.py; specialized 2D and 2D→1D variants are defined locally in multi_launch_builder/o_ffn_multi.py (_build_add_2d_to_2d, _build_add_2d_to_1d)
External C++None — pure MLIR/codegen
Maps to model opResidual #1 (after attention), Residual #2 (after FFN) (Part A2 ops #9, #15)
Production usageTwo adds inside o_ffn.elf (one for each residual); two analogous adds inside o_gemv_ffn.elf
NPU compute tile usageherd [8, 1] = 8 of 32 tiles. The 8 tiles split the row dim. Pure DMA-bound: the add itself is one cycle per element, so total time = DDR↔L1 transfer time. More tiles wouldn't help.
+ +

How it's compiled. The simplest kernel: an air.herd with a tiled elementwise loop, lowered by aircc to the AIE add intrinsic. The 2D and 2D→1D variants exist because the residual outputs may be consumed as flat 1D arrays by the next sub-launch (e.g., the final o_ffn output is 1D n_total = seq*emb); the variant just calls memref.collapse_shape internally to handle the type mismatch.

+ +

Quirk: like RMSNorm, the simple add builder produces a bare air.herd; multi-launch stitching wraps it via _wrap_ir_in_launch.

+ + +

B2.8 — Compile-time helpers and orchestration

+ +

Two files coordinate the actual external-C++ compilation:

+ + + + + +
FileWhat it does
kernel_builder/external_kernels.pyPer-kernel compile_* functions (one per .o) + a compile_all_external_kernels(head_dim) top-level that runs all 5 (silu_and_mul, rope, attn, mv, mv_k8192). Each uses Peano via $PEANO_INSTALL_DIR/bin/clang++. Skips compilation if the .o already exists.
kernel_builder/cache.py:prepare_air_projectCalled from compile_and_cache before each ELF compile. Cleans air_project/, calls compile_all_external_kernels, then copies all .o files into air_project/ where aiecc's link_with search path will find them.
+ +

So the flow for compiling one ELF is: prepare_air_project → external C++ .o files exist in air_project/backend.compile(mlir_module) runs aircc + aiecc, which links the .os into the per-tile ELFs → output .elf + .insts.bin are copied into cache_dir/.

+ +
+ Bottom line on the building blocks: 7 unique compute kernels. Three are MLIR-only codegen (RMSNorm, GEMM, eltwise add) and four are MLIR + hand-written C++ linked via Peano-compiled .o files (GEMV, RoPE, SwiGLU, FlashAttention). A single ELF can contain one or many of these — see B5 for stitching. +
+ +

Tile-mapping summary

+ +

Side-by-side view of how each of the 7 kernels maps onto the NPU2 8×4 compute array:

+ + + + + + + + + + + +
KernelPhaseHerd shapeTilesWhy this shape
RMSNormBoth[8, 1]8Per-row reduction; 8-tile column splits rows
GEMMPrefill[8, 4]32Full 2D output-tile parallelism (M and N)
GEMVDecode[8, 1]8M=1 forces output-row-only parallelism
RoPEPrefill[8, 1]8S=2048 rows split across 8 tiles
RoPEDecode[1, 1]1Only 1 row to rotate; multi-tile would just add fan-out overhead
SwiGLUBoth[8, 1]8Memory-bound; more tiles wouldn't help
Eltwise AddBoth[8, 1]8DMA-bound; 1-cycle add
FlashAttentionPrefillcascade [c_nq, c_ns]~16-24Multi-segment Q-tile cascade pipeline
+ +

Observation: only the prefill GEMM uses the entire 32-tile array. Most kernels use 8 tiles (one column) — they are limited by either the reduction structure (RMSNorm) or by DMA bandwidth (SwiGLU, eltwise add). For decode, the loss of M-direction parallelism (M=1) means there is simply no work for the additional column dim, so even GEMV drops to 8 tiles. Implication: the M=1 decode path leaves 24/32 = 75% of the compute array idle on every dispatch, which is one reason the per-token throughput is dispatch-overhead-bound (see ablation Plan 0).

+ + +

B3. From standalone kernels to end-to-end inference — the four gaps

+ +

B2 covered each kernel as a standalone unit — what it computes, how it's compiled, and how many tiles it uses. But you cannot just chain those 7 kernels together and get a working 1.27 s prefill. Several practical problems sit between "I have a working RMSNorm kernel" and "I have a 16-layer transformer running on the NPU":

+ + + + + + + + + + + + + + + + + + + + + + + + + + + +
GapProblem if unsolvedSolutionSection
#1 — Layout matchingKernel A's output shape/layout doesn't match what kernel B expects to read. Naive chaining produces wrong values or silently misaligned data.CPU pre-transpose of weights, free MLIR reshapes, deliberate physical KV-cache transpose on the host side, mv_k8192 macro-rename trick.B4
#2 — XRT dispatch overheadEach xrt.run() call has ~100 µs fixed overhead. With 49 kernels per prefill pass × 16 layers, dispatch alone would dominate runtime.Stitch multiple air.launchs into one ELF so 6-8 logical kernels run from a single xrt.run() call. Intermediates flow via DDR, host stays out of the loop.B5
#3 — Per-call BO managementNaive flow re-allocates and re-uploads every kernel argument on every call. A 14 MB weight tensor uploaded per kernel call would dominate the ~30 ms-per-call budget.Allocate XRT Buffer Objects once, classify each arg as static (write-once), intermediate (no host transfer at all), or output (host-readable). Skip everything that hasn't changed.B6
#4 — Compile time + per-layer stateEach ELF compile takes ~30-50 s. Recompiling on every script start costs 3+ minutes. Also: 16 layers × 6 ELFs × N weights each → which BO holds which layer's weights?KernelCache persists compiled ELFs to disk, caches loaded XRT contexts in process, and maintains per-layer BO sets keyed by bo_key="rms_gemms_rope_L{layer}".B7
+ +

Sections B4-B7 cover each gap one at a time. Once they're all in place, the prefill (B8) and decode (B9) detail sections show the four gaps working together on real per-layer code paths. B10 is the final code map.

+ +
+ Why this ordering matters. Each gap solution depends on understanding the previous one: layout decisions (B4) constrain what can be stitched into one ELF (B5); the stitched ELF's input layout determines BO classification (B6); BO classification determines what KernelCache needs to track per layer (B7). Skipping ahead leaves you with isolated tricks; reading in order shows why each was necessary. +
+ +
+ Want to know how much each gap contributes? See the companion ABLATION_STUDY.html for a controlled 4-cell measurement that quantifies the marginal speedup from each gap, separately for decode (Plan 0) and prefill (Plan 1). Spoiler: the dominant optimization flips between phases. +
+ + +

B4. Gap #1 — Layout matching between kernels

+ +

The 7 building-block kernels were each developed in their own standalone programming_examples demo. Their input/output layouts were chosen for that demo's convenience — not for chaining into a transformer. Several layout mismatches show up the moment you try to feed one kernel's output into another:

+ +

Mismatch #1 — Weight matrix orientation (GEMV)

+ +

HuggingFace stores Llama weights as (out_features, in_features): e.g. wq has shape (2048, 2048) with the FIRST dim being the output. The standalone GEMV kernel, however, expects A[M, K] with M=output, K=input — but reads A contiguously in K-major order (last dim is the contiguous one). HuggingFace storage is output-major. Naive use → reading the wrong elements per MMA, silent garbage output.

+ +

Fix: CPU pre-transpose every decode-side weight matrix once, before any timing starts. Implemented in llama32_1b_inference.py:171-197 inside prepare_runtime:

+ +
# Pre-transpose all decode GEMV weights (one-time, before timing)
+for lw in weights.layers:
+    lw._wq_t   = np.ascontiguousarray(lw.wq.astype(bfloat16).reshape(emb_dim, emb_dim).T)
+    lw._wk_t   = np.ascontiguousarray(lw.wk.astype(bfloat16).reshape(emb_dim, kv_dim).T)
+    lw._wv_t   = np.ascontiguousarray(lw.wv.astype(bfloat16).reshape(emb_dim, kv_dim).T)
+    lw._wo_t   = np.ascontiguousarray(lw.wo.astype(bfloat16).reshape(emb_dim, emb_dim).T)
+    lw._wgate_t = np.ascontiguousarray(lw.w_gate.astype(bfloat16).reshape(emb_dim, hidden_dim).T)
+    lw._wup_t   = np.ascontiguousarray(lw.w_up.astype(bfloat16).reshape(emb_dim, hidden_dim).T)
+    lw._wdown_t = np.ascontiguousarray(lw.w_down.astype(bfloat16).reshape(hidden_dim, emb_dim).T)
+ +

The .T + ascontiguousarray physically reorders the weight matrix bytes in DDR so the GEMV kernel reads them in K-major order naturally. This costs ~50 ms per layer × 16 layers ≈ 800 ms ONCE at startup, then never again — the transposed buffers live on as _wq_t, _wk_t, etc. and get uploaded to NPU BOs during weight preload.

+ +

Why CPU and not on the NPU? The NPU DMA engine has stride=1 mandatory for sub-32-bit types (it can't do a strided BF16 DMA). Doing the transpose during DMA-in would require shape rearrangement that the DMA hardware refuses. So the transpose lives in numpy on the CPU.

+ +

Mismatch #2 — KV cache layout (prefill ↔ FlashAttention ↔ decode)

+ +

The same physical KV tensor is touched by three different consumers, each with its own preferred layout:

+ + + + + + + +
ConsumerWants layout
RoPE K kernel output (prefill)[seq, n_kv_heads, head_dim] — sequence-major
FlashAttention input (prefill)[seq, n_kv_heads, head_dim] — sequence-major (matches RoPE)
KV cache storage (host)[n_kv_heads, max_seq, head_dim] — head-major (so per-head slicing is contiguous)
Decode CPU attention (per-token reads)[n_kv_heads, current_pos+1, head_dim] — needs head-major for fast per-head dot-products
+ +

Solution: the prefill kernels keep the seq-major layout that RoPE produces (so RoPE→FlashAttention has a free zero-cost layout match), and the host transposes once after each layer's prefill output to populate the head-major KV cache. From llama32_1b_inference.py:401-410:

+ +
k_cache[layer_idx, :, :seq_len, :] = (
+    intermediates["k_roped"]
+        .astype(bfloat16)
+        .reshape(seq_len, n_kv_heads, head_dim)
+        .transpose(1, 0, 2)        # seq-major → head-major
+)
+v_cache[layer_idx, :, :seq_len, :] = (
+    intermediates["v"].astype(bfloat16)
+        .reshape(seq_len, n_kv_heads, head_dim)
+        .transpose(1, 0, 2)
+)
+ +

This transpose runs on the CPU (~1 ms per layer) for the same DMA-stride reason as Mismatch #1. The bf16 stride=1 hardware limit means you cannot do a layout transpose during NPU DMA-out; the host has to materialize the head-major view itself. (See BF16 DMA stride limitation note in project docs.)

+ +

Mismatch #3 — GEMM output flat shape vs. RoPE multi-head input

+ +

Q/K GEMM emits [seq, n_heads * head_dim] as a flat 2D tensor. RoPE expects [seq, n_heads, head_dim] so it can apply the per-(head, dim/2) rotation. This one is FREE — it's a pure shape view, no data movement. The MLIR builder uses memref.expand_shape on the L2 buffer between the GEMM air.launch and the RoPE air.launch inside the same stitched ELF (no DDR round-trip, no DMA reshape). Same trick at the eltwise-add → next-RMSNorm boundary.

+ +

Mismatch #4 — FFN flat output for the next layer

+ +

o_ffn.elf's final output (after the second residual add) is shaped [seq, emb] as far as the math cares, but the next layer's rms_gemms_rope.elf wants its input as a flat 1D [seq * emb] buffer (because that's how the leading RMSNorm's L2 tile shape was specified). The eltwise-add kernel gained a _build_add_2d_to_1d variant that calls memref.collapse_shape internally so the producer and consumer agree on a flat 1D buffer. See multi_launch_builder/o_ffn_multi.py.

+ +

Mismatch #5 — Two GEMV variants in one ELF (K=2048 and K=8192)

+ +

The decode o_gemv_ffn.elf contains FOUR GEMVs: O, Gate, Up, and Down. Three of them have K=2048 (the embedding dim); the Down GEMV alone has K=8192 (the FFN hidden dim, accumulating back to embedding). MLIR can't have two private functions with the same name and different signatures in one module.

+ +

Solution (from kernel_builder/external_kernels.py:155): compile mv.cc a SECOND time with macro renames, producing a separate symbol for the K=8192 variant:

+ +
def compile_mv_k8192():
+    _compile_kernel(src, "mv_k8192.o", extra_flags=[
+        "-DDIM_M_OUTPUT=2",
+        "-Dmatvec_vectorized_bf16_bf16=dg_matvec_vectorized_bf16_bf16",  # renamed
+        "-Dlinalg_fill_bf16=dg_linalg_fill_bf16",
+    ])
+ +

Both .o files end up in air_project/ at link time. The MLIR module references each one by its (renamed) symbol, and the linker happily places both into the same ELF.

+ +
+ Bottom line on layout matching: three of the five mismatches are fixed by FREE MLIR reshapes inside stitched ELFs (zero-cost, no data movement). Two require physical CPU work — both are forced by the AIE DMA's stride=1 limitation on sub-32-bit types, which prevents an NPU-side bf16 transpose. Total CPU layout cost: ~800 ms one-time at startup (weight pre-transpose) plus ~1 ms × 16 layers ≈ 16 ms per prefill pass (KV cache transpose). Both are completely outside the timed prefill loop or rounded into negligible cost. +
+ + +

B5. Gap #2 — Multi-launch ELF stitching

+ +

The problem. Each xrt.run() call has fixed dispatch overhead (kernel-handle lookup, host↔device synchronization) of ~100 µs. With 7 kernels per layer × 16 layers = 112 NPU calls per prefill pass, dispatch alone is ~11 ms — small relative to a 1.2 s prefill, but devastating for decode where each kernel does only hundreds of µs of NPU work. For decode, raw dispatch overhead can rival the actual compute time.

+ +

The fix. Combine multiple kernels into one ELF that runs in one xrt.run() call. The host issues one dispatch; intermediates flow between sub-kernels via DDR using NPU DMA, with no host involvement. From the host's view, "rms_gemms_rope" looks like one kernel even though it's really 6 stitched air.launchs back-to-back.

+ +

The mechanism

+ +

An MLIR module can contain multiple air.launch operations inside a single func.func. Each air.launch wraps an air.segment wrapping air.herd(s) — i.e., one logical kernel. When that combined module is compiled to one ELF and invoked by one xrt.run(), the launches execute sequentially and intermediates flow between them via DDR using NPU DMA — without CPU involvement.

+ +

The Python builders in multi_launch_builder/*_multi.py do this stitching. They take individual MLIR modules (from B2's per-kernel builders) as text strings and concatenate the function bodies into one combined func, with SSA values renamed to avoid collisions.

+ +

The 6 production ELFs (stitched products)

+ +

The production code stitches the 7 kernel building blocks from B2 into 6 ELFs:

+ + + + + + + + + +
ELFPhaseStitched kernelsBuilderCompile time
rms_gemms_rope.elfPrefill6: RMSNorm + Q GEMM + K GEMM + V GEMM + RoPE Q + RoPE Kmulti_launch_builder/rms_gemms_rope_multi.py:193~33 s
flash_attn.elfPrefill1: FlashAttentionflash_attention/.../attn_npu2_seqfirst.py~46 s
o_ffn.elfPrefill8: O GEMM + Add + RMSNorm + Gate GEMM + Up GEMM + SwiGLU + Down GEMM + Addmulti_launch_builder/o_ffn_multi.py:178~50 s
rms_gemv_rope.elfDecode6: RMSNorm + Q/K/V GEMV + RoPE Q + RoPE K (GEMV variants)multi_launch_builder/rms_gemv_rope_multi.py:369~3 s
o_gemv_ffn.elfDecode8: O GEMV + Add + RMSNorm + Gate/Up GEMV + SwiGLU + Down GEMV + Add (GEMV variants)multi_launch_builder/o_gemv_ffn_multi.py~7 s
lm_head_gemv.elfBoth8: identical 8-partition GEMV stitched 8 timesmulti_launch_builder/lm_head_gemv_multi.py~13 s
+ +

So one prefill layer = 3 NPU calls (rms_gemms_rope + flash_attn + o_ffn) covering 15 sub-launches. Without stitching it would be 15 NPU calls per layer × 16 layers = 240 calls per prefill. With stitching it's 48 calls per prefill (16 × 3).

+ +

Why FlashAttention is its own ELF (un-mergeable)

+ +

FA's MLIR uses many air.channels for its cascade-of-tiles design. The air-opt-shim-dma-bds compiler pass scales super-linearly with the number of channels in a module. With 9+ stitched launches in one ELF (i.e., FA + the rms_gemms_rope launches), this pass takes >10 minutes — empirically prohibitive. So the production split is: FA stays as a 1-launch ELF, called between the stitched rms_gemms_rope and o_ffn. That's why one prefill layer is 3 NPU calls, not 1.

+ +

How stitching works (text-based)

+ +

All in kernel_builder/stitching.py as text-manipulation utilities. No MLIR Python API for moving operations between modules — every operation belongs to a Context, and you can't lift a region from one func and graft it into another. Text-based stitching sidesteps this.

+ +

The algorithm:

+
    +
  1. Build each sub-kernel as its own complete MLIR module (using B2's per-kernel builders).
  2. +
  3. Extract each module's func.func body (just the operations between signature and return).
  4. +
  5. Rename all SSA values, affine maps, and symbols with a unique prefix to avoid collisions.
  6. +
  7. Remap the original %argN references to the combined function's arg indices (this is what threads the data flow between launches).
  8. +
  9. Concatenate all bodies into one combined func, surrounded by combined affine map declarations and external function decls.
  10. +
  11. Parse the resulting text with mlir.ir.Module.parse(...) to validate.
  12. +
+ +

Concrete example: how rms_gemms_rope is stitched

+ +
# multi_launch_builder/rms_gemms_rope_multi.py:466-481 (paraphrased)
+bodies, maps_all = [], []
+for ir, prefix, arg_map in [
+    (rms_ir,    "r",  {0:0, 1:1, 2:2}),       # RMSNorm: x_in, norm_w, normed
+    (q_ir,      "q",  {0:2, 1:3, 2:4}),       # Q GEMM: normed (=arg2), wq (=arg3), q (=arg4)
+    (k_ir,      "k",  {0:2, 1:5, 2:6}),       # K GEMM: normed, wk (=arg5), k (=arg6)
+    (v_ir,      "v",  {0:2, 1:7, 2:8}),       # V GEMM: normed, wv (=arg7), v (=arg8)
+    (rope_q_ir, "rq", {0:4, 1:9, 2:11}),      # RoPE Q: q (=arg4), lut_q (=arg9), q_roped (=arg11)
+    (rope_k_ir, "rk", {0:6, 1:10, 2:12}),     # RoPE K: k (=arg6), lut_k (=arg10), k_roped (=arg12)
+]:
+    body = _extract_between_func_and_return(ir)
+    maps = _extract_affine_maps(ir)
+    body = _rename_all_with_externs(body, prefix, _EXTERN_FUNCS)  # prefix all SSA
+    maps = [_rename_all_with_externs(m, prefix, _EXTERN_FUNCS) for m in maps]
+    body = _fix_launch_func_args(body, prefix, arg_map)             # remap arg refs
+    bodies.append(body)
+    maps_all.extend(maps)
+
+# Then assemble: module { #maps... func.func @rms_gemms_rope(13 args) { bodies... return } }
+ +

The arg_map values are what enable data flow: {0:2, 1:3, 2:4} for Q GEMM means "the Q GEMM's slot 0 (its activation input) connects to the combined func's slot 2 (which is the RMSNorm output, normed)". Same DDR buffer, no host hop between RMSNorm and Q GEMM.

+ +

Stitching helpers in kernel_builder/stitching.py

+ + + + + + + + + +
FunctionWhat it does
_extract_between_func_and_return(mlir)Returns the body of the public func.func — everything between signature and return.
_extract_affine_maps(mlir)Returns the #map0 = ..., #map1 = ... declarations from the module header.
_extract_private_funcs(mlir)Returns func.func private declarations (e.g., external C++ kernel decls like @matvec_vectorized_bf16_bf16).
_rename_all(text, prefix)Renames every SSA value (%arg0%q_arg0), every affine map (#map0#q_map0), every symbol (@herd_0@q_herd_0) — but preserves external kernel function names.
_fix_launch_func_args(text, prefix, arg_map)After rename, fixes air.launch args(...) references to point at the COMBINED func's arg slots, not the per-sub-kernel ones.
_wrap_ir_in_launch(mlir)Some sub-builders (RMSNorm, eltwise add) emit a bare air.herd not wrapped in air.launch. This wraps it in air.launch { air.segment { herd } } — required because airrt-to-npu only sees segment_load ops.
+ +
+ What stitching saves vs. what it doesn't: stitching saves XRT dispatch overhead (one xrt.run vs N) and host orchestration (no host round-trip between launches). It does NOT save DDR traffic — intermediates still go through DDR; the launches just read/write that DDR via NPU DMA without involving the host. See ABLATION_STUDY Plan 0 (decode) for the measured contribution of pure merging — 1.71× alone, with another 1.60× from per-layer weight BOs (B7), totalling A→D = 2.75×. Plan 1 (prefill) shows the contribution shifts dramatically at prefill scale. +
+ +

Intra-ELF vs inter-ELF intermediate flow — what the production design actually does

+ +

This is the easiest place to get confused, so it's worth being explicit. The "stay on NPU" property of stitched intermediates applies only inside one ELF. As soon as you cross from one xrt.run() to another (e.g., rms_gemms_ropeflash_attno_ffn), the intermediates go through the host by default.

+ + + + + + + + + + + + + + + +
BoundaryHow intermediates flowCost per transferWhat "production" does
Intra-ELF
between sub-launches inside one merged ELF (e.g., RMSNorm → Q GEMM inside rms_gemms_rope)
NPU DMA reads from / writes to the same DDR-resident BO. Host is completely uninvolved during the xrt.run().~µs (NPU-internal DMA, dominated by L2/L1 fan-out)Always uses NPU-only flow. Marked via intermediate_indices so KernelCache neither host-writes on entry nor host-reads on exit.
Inter-ELF
between two separate xrt.run() calls (e.g., rms_gemms_ropeflash_attn)
By default: producer's output BO → sync(FROM_DEVICE) → host numpy view → next call's memcpy + sync(TO_DEVICE) into a SEPARATE BO. Two cache-coherent transfers + a memcpy per intermediate.~µs/MB at PCIe-equivalent bandwidth; per prefill layer the inter-ELF traffic adds up to ~40 MB round-tripProduction uses the host-broker pattern even though BO aliasing is technically possible (ablation Cell C demonstrates the alternative). See D2 for why production accepts this and what it would take to remove.
+ +

Concrete prefill numbers per pass (16 layers × 3 ELF dispatches per layer):

+ + + + + + + + + +
WherePer layerPer pass (16 layers)
Inside rms_gemms_rope (6 launches stitched)0 host transport (5 NPU-only handoffs)0
rms_gemms_ropeflash_attn (Q + K + V, host-broker)~12 MB ↓↑ (Q=8 MB, K=2 MB, V=2 MB)~192 MB
flash_attno_ffn (attn_out, host-broker)~8 MB ↓↑~128 MB
Inside o_ffn (8 launches stitched)0 host transport (7 NPU-only handoffs)0
K, V to KV cache (host transpose, B4)~4 MB ↓ each, plus CPU transpose~64 MB ↓ + ~16 ms CPU
Total inter-ELF host↔device traffic per pass~640 MB round-trip
+ +

At ~20 GB/s of host↔device bandwidth, ~640 MB ≈ ~32 ms ≈ 3% of the 1.13 s prefill. Decode is much smaller because per-token intermediates are KB-scale: ~10 KB per inter-ELF transfer × 33 NPU calls per token = a few MB, well under measurement noise. So inter-ELF host-broker is a real prefill cost, but tiny in decode.

+ +
+ So what's the design trade-off? Inter-ELF BO aliasing IS technically feasible (proven by ablation Cell C). Production chose the host-broker pattern for code simplicity — managing a cross-ELF BO graph + the MLIR shape conversions + lifetime tracking is non-trivial. The 3% prefill speedup is left on the table as known optimization headroom; see D2 in the Future work section. +
+ + +

B6. Gap #3 — Anatomy of one NPU call (BOs and host↔device data flow)

+ +

The problem. A stitched ELF (B5) hides 6-8 sub-launches behind one xrt.run(). But that single call still has to: get every input from host RAM into NPU-accessible DDR, hand the kernel handles to those buffers, run the kernel, and read outputs back. Done naively, every call would re-allocate buffers and re-upload weights — for a 14 MB wq tensor, that's ~5 ms of PCIe traffic per call, or ~80 ms × 16 layers = 1.3 s extra per prefill pass. The kernel finishes in tens of milliseconds; we cannot afford 5+ ms of host overhead per call.

+ +

This section explains what happens during ONE xrt.run() at the BO (Buffer Object) level — the unit of memory the NPU can read and write. Once you understand this anatomy, the per-layer BO trick in B7 (KernelCache) is straightforward.

+ +

What is a Buffer Object (BO)?

+ +

A BO is an XRT abstraction for a chunk of NPU-accessible memory. Physically it lives in DDR — the same RAM the host uses, but with a NPU-readable mapping. Created by xrt.bo(device, size_bytes, ...). Two operations matter:

+ + + + + + +
OpCostWhat it does
bo.map()~freeReturns a host pointer you can memcpy into. Host writes go to RAM directly.
bo.sync(TO_DEVICE)~µs/MB (cache flush)Flush host CPU caches so the NPU sees the up-to-date bytes when it DMAs from DDR.
bo.sync(FROM_DEVICE)~µs/MB (cache invalidate)Invalidate host CPU caches so the host sees the up-to-date bytes the NPU wrote.
+ +

The kernel doesn't get bytes — it gets a list of BOs (one per func.func argument), and the kernel's compiled code uses NPU DMA to stream chunks of those BOs into per-tile L1 / L2 SRAM as it runs.

+ +

The five steps of one xrt.run()

+ + + + + + + + +
StepWhat happensCost (typical)
1. Resolve XRT contextLook up the loaded xclbin for this kernel name; get the device handle and kernel symbol.~µs (cached)
2. Resolve BO listLook up or allocate the BO array for this bo_key. One BO per kernel argument.~µs (cached) or ~ms (first allocation)
3. Write inputsFor each non-static, non-intermediate input: memcpy(bo.map(), input_array) + bo.sync(TO_DEVICE). Static slots (weights) and intermediate slots (kernel-overwritten) are SKIPPED on every call after the first.~µs/MB per slot actually written
4. Submit kernelinvoker.run(*bos) — XRT enqueues the kernel and the call blocks until completion.~100 µs dispatch overhead + actual NPU compute time
5. Read outputsFor each slot in output_indices: bo.sync(FROM_DEVICE) + return a numpy view onto bo.map(). Other slots get a 0-length placeholder.~µs/MB per output
+ +

The three index sets — the per-call control knobs

+ +

Every load_and_run call (B7) accepts three optional sets that control which slots get host↔device data movement:

+ + + + + + +
SetMeaningEffect
output_indicesSlots the caller wants to read back to host (e.g., q_roped, k_roped).Triggers sync(FROM_DEVICE) for those slots only. Other slots get a 0-length placeholder in the return tuple.
static_input_indicesSlots holding weights/LUTs that are pre-loaded once and never change (e.g., wq, norm_w, RoPE LUT).Skipped by the host write loop on every call after the first. Combined with bo_key, lets per-layer weights persist on device across calls.
intermediate_indicesSlots the kernel will OVERWRITE — entry contents don't matter (e.g., the normed output of RMSNorm that the next launch reads).Skipped by the host write loop on every call after the first. Saves a memcpy + sync for buffers the host never needs to read or initialize.
+ +

These sets are what makes per-call cost go from "upload everything" (~ms) to "upload only the new activation" (~µs).

+ +

What ONE prefill kernel call actually does (concrete: rms_gemms_rope, layer 5, mid-prefill)

+ +
# Argument layout for rms_gemms_rope (13 slots, see B5/B7 for full list):
+#   0: x_in           ← layer activation, CHANGES every call
+#   1: norm_w         ← layer 5's RMSNorm weight, STATIC
+#   2: normed         ← intermediate (RMSNorm → GEMM)
+#   3: wq             ← layer 5's Q weight (~14 MB), STATIC
+#   4: q              ← intermediate (GEMM → RoPE)
+#   5: wk             ← layer 5's K weight (~3.5 MB), STATIC
+#   6: k              ← intermediate
+#   7: wv             ← layer 5's V weight (~3.5 MB), STATIC
+#   8: v              ← intermediate
+#   9: rope_lut_q     ← STATIC (LUT)
+#  10: rope_lut_k     ← STATIC
+#  11: q_roped        ← intermediate, but caller wants to READ it (output_index)
+#  12: k_roped        ← intermediate, but caller wants to READ it (output_index)
+
+cache.load_and_run(
+    "rms_gemms_rope", RGR_BACKEND,
+    x_in_bf16,                              # slot 0 (only this gets written)
+    lw.attn_norm,    np.zeros(...),       # slots 1, 2
+    lw.wq,           np.zeros(...),       # slots 3, 4
+    lw.wk,           np.zeros(...),       # slots 5, 6
+    lw.wv,           np.zeros(...),       # slots 7, 8
+    rope_lut_q, rope_lut_k,                 # slots 9, 10
+    np.zeros(...), np.zeros(...),       # slots 11, 12 (output buffers)
+    output_indices=[11, 12],
+    static_input_indices={1, 3, 5, 7, 9, 10},
+    intermediate_indices={2, 4, 6, 8, 11, 12},
+    bo_key=f"rms_gemms_rope_L5",         # this layer's BO set
+)
+ +

Per-call work: ONE memcpy (slot 0, ~8 KB) + ONE sync(TO_DEVICE) + run + TWO sync(FROM_DEVICE) (slots 11, 12). All 21 MB of weights stay resident on the NPU's BOs — the host doesn't touch them. Without static_input_indices + bo_key, the same call would memcpy and sync ~21 MB of weights every single time.

+ +
+ Bottom line on the per-call anatomy: the BO model lets you separate "what data does the NPU need" from "what does the host need to send THIS call". The three index sets (output / static / intermediate) plus the bo_key are the entire vocabulary for that separation. Whoever owns the load_and_run contract (B7) gets to make every call cheap — even the kernel-call burst inside a tight per-token decode loop. +
+ +

One important scope note: BOs are per-call, not shared across calls

+ +

Each load_and_run call resolves its own BO list via bo_key. Two different kernels (or two calls with different bo_keys) get independent BOs even if they conceptually pass the same intermediate. So:

+ +
    +
  • Inside one xrt.run(): the merged ELF's sub-launches all see the SAME BO list, so an intermediate written by sub-launch N is automatically visible to sub-launch N+1 (just two MLIR launches reading/writing the same arg slot). No host involvement.
  • +
  • Across two xrt.run() calls: kernel A's BOs and kernel B's BOs are different XRT objects in different _cached_bos entries. To get A's output into B's input you EITHER (1) sync to host and re-upload to B's BO (the default — host-broker), OR (2) explicitly alias B's input BO to point at A's output BO via a manual _share_bo trick (the ablation Cell C technique).
  • +
+ +

Production uses (1) for cross-kernel-group transfers — see the per-pass cost breakdown in B5 "Intra-ELF vs inter-ELF intermediate flow". Path (2) is the optimization tracked in D2 (Future work).

+ + +

B7. Gap #4 — KernelCache: compile-once, per-layer BO sets

+ +

The problem. Two costs would otherwise dominate every script start AND every kernel call:

+
    +
  1. Compile time. Compiling all 6 production ELFs takes ~3 minutes (B5 table). Recompiling on every python llama32_1b_inference.py run is unworkable.
  2. +
  3. BO management state. 16 layers × 6 ELFs × ~6 weight slots ≈ ~600 weight BOs holding ~1 GB of pre-uploaded weights need to stay alive and be addressable. Naively re-allocating per call would also dominate.
  4. +
+ +

KernelCache (in kernel_builder/cache.py:183) is the single class that solves both. It's the bridge between the per-call BO anatomy (B6) and the realities of running a 16-layer transformer.

+ +

Three layers of caching

+ + + + + + +
LayerWhat's cachedLifetimeKey
1. Disk artifactCompiled .elf + .insts.bin + kernel symbol namePersistent (until make clean)name (e.g. "rms_gemms_rope")
2. XRT contextLoaded XRT device + xclbin + kernel handleProcess lifetimename
3. Buffer ObjectsAllocated xrt.bo objects (one per kernel arg)Process lifetimebo_key (defaults to name; overridden per layer)
+ +

Layer 1 saves the 3-minute compile. Layer 2 saves the ~100 ms xclbin reload per kernel call. Layer 3 (combined with static_input_indices from B6) saves the per-call weight upload.

+ +

Class signature and state

+ +
class KernelCache:
+    def __init__(self, cache_dir=None, verbose=False, profiler=None):
+        self.cache_dir = Path(cache_dir)         # where .elf files persist on disk
+        self.profiler = profiler or Profiler()
+        self.artifacts = {}      # Layer 1: name → XRTCompileArtifact (paths + symbol)
+        self._loaded = {}        # Layer 2: name → (backend, invoker) — XRT handles
+        self._cached_bos = {}    # Layer 3: bo_key → list[xrt.bo] — per-session BOs
+ +

The two methods

+ +

compile_and_cache(name, mlir_module, backend_kwargs) — called ONCE per ELF

+ +
# kernel_builder/cache.py:251 (paraphrased)
+def compile_and_cache(self, name, mlir_module, backend_kwargs, output_binary_name="air"):
+    prepare_air_project()                          # clear air_project/ + compile .o files
+    backend = XRTBackend(**backend_kwargs)
+    artifact = backend.compile(mlir_module, ...)   # aircc → aiecc → .elf (the slow step)
+
+    cached_binary = self.cache_dir / f"{name}{ext}"
+    shutil.copy2(artifact.output_binary, cached_binary)
+
+    self.artifacts[name] = XRTCompileArtifact(str(cached_binary), artifact.kernel, cached_insts)
+    backend.unload()
+ +

Records name → cached_binary_path in self.artifacts. _save_manifest() writes the dict to cache_dir/manifest.json so a subsequent run with --run-only skips compilation entirely via load_manifest(). This is the difference between a 3-minute startup and a 5-second startup.

+ +

load_and_run(name, backend_kwargs, *inputs, ...) — called dozens of times per inference

+ +

This is the implementation of the per-NPU-call anatomy from B6. Annotated:

+ +
# kernel_builder/cache.py:294 (paraphrased — the contract)
+def load_and_run(self, name, backend_kwargs, *inputs,
+                 output_indices=None,
+                 static_input_indices=None,
+                 intermediate_indices=None,
+                 bo_key=None,
+                 naive=False):                   # naive=True is for the ablation study only
+
+    # 1. Lookup or load XRT context for this kernel name (Layer 2)
+    if name not in self._loaded:
+        backend = XRTBackend(**backend_kwargs)
+        backend.load(self.artifacts[name])
+        self._loaded[name] = (backend, backend.invoker)
+
+    # 2. Lookup or allocate BO list for this bo_key (Layer 3)
+    bo_key = bo_key or name             # default: shared BOs per kernel
+    if bo_key not in self._cached_bos:
+        bos = [allocate_bo(arr.nbytes) for arr in inputs]
+        self._cached_bos[bo_key] = bos
+        first_call = True
+    else:
+        bos = self._cached_bos[bo_key]
+        first_call = False
+
+    # 3. Write inputs (skipping static + intermediate after first call)
+    static = static_input_indices or set()
+    intermediate = intermediate_indices or set()
+    skip = (static | intermediate) if not first_call else set()
+
+    for i, arr in enumerate(inputs):
+        if i in skip:
+            continue                       # BO already has the right data
+        memcpy(bos[i].map(), arr)
+        bos[i].sync(TO_DEVICE)              # host → DDR
+
+    # 4. Run the kernel
+    invoker.run(*bos)
+
+    # 5. Read back only the requested outputs
+    output_indices = output_indices or [len(inputs) - 1]
+    results = []
+    for i, arr in enumerate(inputs):
+        if i in output_indices:
+            bos[i].sync(FROM_DEVICE)         # DDR → host
+            results.append(np_view(bos[i].map(), arr.shape, arr.dtype))
+        else:
+            results.append(np.empty(0, dtype=arr.dtype))   # placeholder
+    return tuple(results)
+ +
+ Two crucial properties of this contract: +
    +
  1. Return tuple has length len(inputs), not len(output_indices). Slots not in output_indices get an empty placeholder. Callers index by original arg position: out[2], out[14], etc.
  2. +
  3. static_input_indices and intermediate_indices only kick in after the first call for a given bo_key. The first call must write everything (the BOs have garbage). The pre-load pattern in prepare_runtime exists specifically to make the first call happen during init, not during timed inference.
  4. +
+
+ +

The bo_key trick — per-layer weight BOs

+ +

The single most consequential decision in the whole codebase. In plain language: give each of the 16 transformer layers its own independent set of NPU BOs, pre-load every layer's weights once at startup, then never re-upload weights again during inference.

+ +

Why the default is too slow

+ +

bo_key defaults to the kernel name (e.g. "rms_gemms_rope") — meaning ALL 16 layers share ONE set of BOs. With 6 weight slots in rms_gemms_rope totaling ~21 MB, the per-layer behavior would be:

+
    +
  • Layer 0: write layer-0 weights into BOs (~21 MB host→DDR), run kernel
  • +
  • Layer 1: BOs now hold layer-0 weights → must overwrite with layer-1 (~21 MB again), run
  • +
  • ... 16 layers total: ~336 MB of weight upload per prefill pass, just to feed the GEMMs
  • +
+ +

That's pure host overhead with zero NPU benefit. For decode, the per-token version of the same problem dominates the entire decode loop.

+ +

The trick: encode layer index in bo_key

+ +

Override bo_key to f"rms_gemms_rope_L{layer_idx}" so each layer gets its own slot in self._cached_bos. After the one-time preload, _cached_bos looks like this:

+ +
# Conceptual view of the cache state after preload
+self._cached_bos = {
+    "rms_gemms_rope_L0":  [bo_x, bo_norm0,  bo_normed, bo_wq0,  bo_q, ...],   # Layer 0's weights pre-uploaded
+    "rms_gemms_rope_L1":  [bo_x, bo_norm1,  bo_normed, bo_wq1,  bo_q, ...],   # Layer 1's weights pre-uploaded
+    "rms_gemms_rope_L2":  [bo_x, bo_norm2,  bo_normed, bo_wq2,  bo_q, ...],   # ...
+    ...
+    "rms_gemms_rope_L15": [bo_x, bo_norm15, bo_normed, bo_wq15, bo_q, ...],
+    "o_ffn_L0": [...],   # Same pattern for the other prefill ELF
+    ...
+}
+ +

16 layers × independent BO sets, each holding its own layer's weights resident on the NPU. Now the per-call code:

+ +
# preload_prefill_weights — runs ONCE before timing starts
+for layer_idx in range(16):
+    cache.load_and_run(
+        "rms_gemms_rope", RGR_BACKEND,
+        np.zeros(...),                                    # slot 0: x_in placeholder
+        weights.layers[layer_idx].attn_norm,                  # slot 1
+        np.zeros(...),                                    # slot 2
+        weights.layers[layer_idx].wq,                         # slot 3 (~14 MB)
+        ...                                                   # slots 4-12
+        bo_key=f"rms_gemms_rope_L{layer_idx}",             # UNIQUE per layer
+    )
+# After this loop: 16 separate BO sets are cached, each with its layer's weights uploaded.
+
+# During TIMED inference, exact same call shape but with the real activation in slot 0:
+for layer_idx in range(16):
+    out = cache.load_and_run(
+        "rms_gemms_rope", RGR_BACKEND,
+        x_bf16,                                               # slot 0: actual activation
+        ...                                                   # slots 1-12 (just placeholders, BOs already have weights)
+        static_input_indices={1, 3, 5, 7, 9, 10},  # skip weight write
+        intermediate_indices={2, 4, 6, 8, 11, 12},
+        bo_key=f"rms_gemms_rope_L{layer_idx}",             # picks layer's pre-loaded BOs
+    )
+ +

Now the timed call uploads ONLY the activation (slot 0, ~8 KB), even though there are 13 args. The 12 weight/intermediate slots are skipped because (static | intermediate) covers them and the BO list lookup hit the cached entry for that layer's bo_key. The ablation study (Plan 0, decode) measured this single optimization as the dominant contributor — 1.60× alone, the largest individual delta of all four gaps.

+ +

Two mechanisms work together: bo_key decides which set of BOs to look up; static_input_indices decides which slots in that set don't need to be re-written. Either alone wouldn't work — without per-layer keys, every layer overwrites every other layer's weights; without the static-skip flag, KernelCache would dutifully re-memcpy every weight slot every call even though the contents are already correct.

+ +

Trade-off: memory for speed

+ +

This is fundamentally a trade memory for speed design. Concrete numbers:

+ + + + + + +
CostDefault (shared bo_key)Per-layer bo_key
NPU-resident BO memory~120 MB (one set per ELF × 6 ELFs)~1.0 GB (16 layers × 6 ELFs)
Host→device upload per prefill pass~336 MB (16 × 21 MB rewrites)~128 KB (just activations)
One-time preload cost0~200-300 ms (once at startup)
+ +

~1 GB of pinned BO memory is acceptable for a 1.24 B-parameter model on a system with 16+ GB of RAM. If memory were tight, you could fall back to shared bo_key and accept the per-call upload cost — the contract would still work, just slower.

+ +

Subtle point: aren't CPU and NPU sharing the same DDR?

+ +

Yes — NPU2 (Strix) is a unified-memory architecture, so the NPU and CPU share the same physical DDR. So why is there still a memcpy + memory duplication?

+ +

Because "shared DDR" doesn't mean "shared allocation". A normal numpy array and an XRT BO live in the same DDR but in different memory regions with different attributes:

+ + + + + +
Buffer kindAllocatorAttributesWho can read it?
numpy weight arrayPython / glibc mallocPageable, virtual, CPU-cachedCPU only
XRT Buffer Objectxrt.bo(device, size)Physically contiguous, pinned (non-pageable), specific cache attributes, mapped into BOTH CPU and NPU virtual address spacesCPU and NPU
+ +

The NPU's DMA engine can ONLY access physically-contiguous, pinned memory — it can't read a random pageable numpy buffer (which is virtually contiguous but physically scattered, and may be swapped out at any moment). So a BO is a special chunk of DDR, requested separately and held alive for the BO's lifetime.

+ +

That means the data flow is genuinely:

+
    +
  1. Weight loaded by HuggingFace → numpy array in pageable RAM (one copy, ~14 MB for wq)
  2. +
  3. Preload calls memcpy(bo.map(), weight_array) → physical byte copy into the BO's pinned region (~3 ms for 14 MB)
  4. +
  5. bo.sync(TO_DEVICE) → flushes CPU L1/L2/L3 caches so the NPU's DMA reads the up-to-date DDR contents (NOT a copy — pure cache management)
  6. +
  7. NPU runs; reads the BO via DMA; writes outputs back
  8. +
  9. For outputs: bo.sync(FROM_DEVICE) → invalidates CPU caches so a subsequent host read sees what the NPU wrote
  10. +
+ +

So yes — even with shared DDR, the production codebase keeps two physical copies of each weight (the numpy array + the BO), and the preload step really does memcpy them. ~1 GB extra memory + ~200-300 ms one-time preload is the price.

+ +

Could it be zero-copy? In principle yes — you could allocate the BO first and then construct a numpy view via np.frombuffer(bo.map(), ...), so the safetensors loader writes directly into the pinned region. The codebase doesn't do this for two reasons:

+
    +
  • The CPU-side weight pre-transpose (B4 mismatch #1) creates new arrays anyway.reshape().T.ascontiguousarray() always materializes a fresh buffer, so the transposed result has to be copied into the BO regardless of how the original was allocated.
  • +
  • Engineering cost vs. payoff — making the weight loader BO-aware would require a custom allocator path through HuggingFace + safetensors, significant complexity for ~200-300 ms savings on a one-time startup cost that's not in the inference critical path.
  • +
+ +

So the codebase trades the simplicity of standard numpy for a small one-time memory + memcpy cost. "Unified memory" eliminates cross-PCIe DMA (which discrete GPUs suffer); it doesn't eliminate the pinned-vs-pageable distinction or the cache-coherency flush.

+ +
+ Bottom line on KernelCache: three caches with three lifetimes (disk / process / process), one method (load_and_run) implementing the B6 anatomy with the index-set contract, and one trick (bo_key=f"name_L{layer_idx}") that turns "16 layers × ~50 MB of weights to upload per call" into "0 weight uploads per call after preload". The trade is ~1 GB of pinned BO memory for ~hundreds of ms saved per inference. Without this class, the codebase wouldn't be 1.27 s prefill — it would be tens of seconds. +
+ + +

B8. Prefill in NPU detail — putting all four gaps together

+ +

Per-layer kernel sequence — 3 NPU calls

+ +
+

Layer N (prefill)

+
+ NPU 1 +
rms_gemms_rope.elf — 6 stitched launches: RMSNorm(x) → Q/K/V projections → RoPE on Q and K. Reads x_in (seq, 2048); writes q_roped (seq, 2048), k_roped (seq, 512), v (seq, 512). Realizes Part A2 ops 1-6.
+
cache.load_and_run("rms_gemms_rope", ...)
+
+
+ NPU 2 +
flash_attn.elf — 1 launch: causal GQA flash attention. Reads q_roped, k_roped, v; writes attn_out (seq, 2048). Also extracts k_cache, v_cache for decode. Realizes Part A2 op 7.
+
cache.load_and_run("flash_attn", ...)
+
+
+ NPU 3 +
o_ffn.elf — 8 stitched launches: O projection → residual add → RMSNorm → Gate/Up GEMMs → SwiGLU → Down GEMM → second residual add. Reads attn_out, x_residual; writes the layer output. Realizes Part A2 ops 8-15.
+
cache.load_and_run("o_ffn", ...)
+
+
+ +

After all 16 layers: CPU RMSNorm on the last token's hidden state (Part A5), then lm_head_gemv.elf (8 partitions, 1 NPU call) → argmax → first generated token.

+ +

Tile usage: rms_gemms_rope's GEMMs use the full [8,4] = 32-tile array; its RMSNorm + RoPE use [8,1] = 8 tiles. flash_attn uses a multi-segment cascade ~16-24 tiles. o_ffn's GEMMs use [8,4] = 32 tiles; its add/RMSNorm/SwiGLU use [8,1] = 8 tiles. See B2.8 tile-mapping summary for the full table.

+ +

Code walk: run_npu_prefill

+ +
# llama32_1b_inference.py:341 — main prefill entry
+def run_npu_prefill(token_ids, weights, config, prefill_cache, decode_cache,
+                    rope_lut_bf16, max_seq, tokenizer, ...):
+    seq_len = len(token_ids)                # 2048
+
+    # Pre-allocate KV cache (16 layers × 8 KV heads × 2048 × 64), see Part A4
+    k_cache = np.zeros((config.n_layers, n_kv_heads, max_seq, head_dim), dtype=bfloat16)
+    v_cache = np.zeros((config.n_layers, n_kv_heads, max_seq, head_dim), dtype=bfloat16)
+
+    # Token embedding (host-side numpy lookup)
+    x_bf16 = weights.embed_table[token_ids].astype(bfloat16)
+
+    # --- TIMED SECTION START ---
+    for layer_idx in range(config.n_layers):           # 16 layers
+        x_bf16, intermediates = run_transformer_block(
+            x_bf16, weights.layers[layer_idx], rope_lut_bf16,
+            config, prefill_cache, layer_idx=layer_idx, ...
+        )
+        # Extract KV cache from this layer's intermediates (see Part A4)
+        k_cache[layer_idx, :, :seq_len, :] = intermediates["k_roped"]...
+        v_cache[layer_idx, :, :seq_len, :] = intermediates["v"]...
+
+    # Find last real token (see Part A5 padding)
+    prompt_len = len([t for t in token_ids if t != tokenizer.eos_token_id])
+    pred_pos = prompt_len - 1
+
+    # Final RMSNorm + LM Head — only the last real-token row
+    last_normed = _rms_norm(x_bf16[pred_pos:pred_pos+1], weights.final_norm)
+
+    # NPU LM Head GEMV — reuse decode-cache 8-partition GEMV ELF
+    results = decode_cache.load_and_run("lm_head_gemv", LM_GEMV_BACKEND, ...)
+    logits_row = np.concatenate(results, axis=0)[:vocab_size]
+    prefill_token = int(np.argmax(logits_row))
+
+    return prefill_token, k_cache, v_cache, prompt_len
+ +

How weights flow into the kernel: prefill preload

+ +

Before any timing starts, preload_prefill_weights writes ALL 16 layers' weights into per-layer NPU BOs:

+ +
# llama32_1b_prefill.py — preload_prefill_weights (paraphrased)
+def preload_prefill_weights(weights, config, cache, seq_len, rope_lut):
+    for layer_idx in range(config.n_layers):              # 16 layers
+        lw = weights.layers[layer_idx]
+        cache.load_and_run(
+            "rms_gemms_rope", RMS_GEMMS_ROPE_BACKEND,
+            np.zeros((seq_len, emb_dim), dtype=bfloat16),  # slot 0: x_in (placeholder)
+            lw.attn_norm.astype(bfloat16),                 # slot 1: norm_w (STATIC)
+            np.zeros((seq_len, emb_dim), dtype=bfloat16),  # slot 2: normed (intermediate)
+            lw.wq.astype(bfloat16),                        # slot 3: wq (STATIC)
+            # ... 9 more args (intermediates + weights + LUTs)
+            output_indices=[11, 12],                   # read q_roped, k_roped back
+            static_input_indices={1, 3, 5, 7, 9, 10},  # weights/LUTs: written once
+            intermediate_indices={2, 4, 6, 8, 11, 12},  # overwritten by kernel
+            bo_key=f"rms_gemms_rope_L{layer_idx}",        # per-layer BO set
+        )
+        # Same pattern for o_ffn ELF — 16 different BO sets, one per layer
+ +
+ The bo_key trick (this is what "per-layer weight BOs" means): KernelCache caches BO objects keyed by bo_key. By using f"rms_gemms_rope_L{layer_idx}", each layer gets its OWN set of NPU BOs. The weights for layer 5 stay in layer 5's BOs and are never overwritten by layer 6. During inference, the timed call uses the same bo_key, so the per-layer weights are already on device — only the x_in activation needs to be host-uploaded. +
+ + +

B9. Decode in NPU detail — putting it all together for per-token generation

+ +

Per-token, per-layer kernel sequence

+ +

Decode works on one token at a time. Per token, per layer, it makes 3 calls (2 NPU + 1 CPU):

+ +
+

Token T, Layer N (decode)

+
+ NPU 1 +
rms_gemv_rope.elf — 6 stitched launches: RMSNorm(x_decode) → Q/K/V GEMVs (each W·x for the single token) → RoPE Q/K. Reads single-token x_in (2048,); writes single-token q_roped (2048,), k_roped (512,), v (512,).
+
cache.load_and_run("rms_gemv_rope", ...)
+
+
+ CPU +
decode_attention_cpu — Single-query GQA attention against the cumulative KV cache (positions 0 to current_pos). Updates KV cache with new k_roped, v. Why CPU? At head_dim=64 the NPU FA path has overhead; CPU is cheap for single-query.
+
llama32_1b_decode.py:96
+
+
+ NPU 2 +
o_gemv_ffn.elf — 8 stitched launches: O GEMV → residual add → RMSNorm → Gate/Up GEMVs → SwiGLU → Down GEMV → second residual add. Output feeds next layer's x_decode.
+
cache.load_and_run("o_gemv_ffn", ...)
+
+
+ +

After all 16 layers (per token): CPU RMSNorm on the resulting hidden state, then lm_head_gemv.elf → argmax → next token.

+ +

Tile usage: EVERY decode kernel uses ≤ 8 tiles (one column of the 8×4 array): the GEMVs are [8,1], RMSNorm + SwiGLU + add are [8,1], and RoPE drops to [1,1] (only one row to rotate). The decode path leaves at least 24/32 = 75% of the compute array idle on every NPU dispatch — one reason decode is dispatch-overhead-bound (see ablation Plan 0: A→D = 2.75× from removing dispatch overhead, not from doing more compute).

+ +

Code walk: the decode loop

+ +
# llama32_1b_inference.py:585 — the decode loop inside generate()
+for token_idx in range(n_tokens):
+    t_token_start = time.perf_counter()
+
+    x = x_decode.copy()                              # single-token activation (emb_dim,)
+    for layer_idx in range(config.n_layers):       # 16 layers
+        x = run_decode_block(
+            x, weights.layers[layer_idx], decode_cache, config,
+            k_cache[layer_idx], v_cache[layer_idx],     # growing each iter
+            current_pos, rope_lut_bf16,
+        )
+
+    # Final RMSNorm (CPU, <1ms for 2048 elements)
+    x_normed = rms_norm(x.astype(np.float32).reshape(1, emb_dim),
+                       weights.final_norm.astype(np.float32))
+
+    # LM Head — NPU 8-partition GEMV (single XRT call, 8 launches in one ELF)
+    x_lm = x_normed.flatten().astype(bfloat16)
+    lm_inputs = [x_lm]                                # slot 0: shared input
+    for p in range(_LM_N_PARTITIONS):                # 8 partitions
+        lm_inputs.append(weights._lm_weight_parts_gemv[p])  # weight
+        lm_inputs.append(np.zeros(_LM_N_PART, dtype=bfloat16))  # output buffer
+
+    lm_results = decode_cache.load_and_run(
+        "lm_head_gemv", LM_GEMV_BACKEND, *lm_inputs,
+        output_indices=[2 + 2*p for p in range(8)],   # 8 outputs
+        static_input_indices={1 + 2*p for p in range(8)},  # weights static
+        intermediate_indices={2 + 2*p for p in range(8)},  # skip output writes
+    )
+
+    # Concatenate 8 partition outputs into one logits array, argmax
+    logits = _assemble_logits(lm_results, vocab_size)
+    next_token = int(np.argmax(logits[0]))
+    generated_tokens.append(next_token)
+    x_decode = weights.embed_table[next_token].astype(bfloat16)
+    current_pos += 1
+
+    if next_token in (tokenizer.eos_token_id, 128009):  # <|eot_id|>
+        break
+ +
+ Why decode uses CPU attention instead of NPU FA: the production NPU FlashAttention kernel was designed for prefill's seq=2048 batch and has overhead for single-query workloads at head_dim=64. CPU attention is faster for the small single-query case. This is documented in profile.md as a known limitation; an NPU decode FA was added for the larger Llama-3B variant (head_dim=128) but isn't used here. +
+ + + +

B10. Code map — where everything lives

+ +

Reference section: a top-down map of every file involved in the production runtime, useful for grepping or for finding the right entry point.

+ +

Top-level Python files programming_examples/llama32_1b/

+ + + + + + + + + +
FileLinesPurpose
llama32_1b_inference.py975Main entry point. Unified prefill + decode pipeline. main() at the bottom.
llama32_1b_prefill.py514Standalone prefill (with profiler report). compile_all_kernels, run_transformer_block, preload_prefill_weights.
llama32_1b_decode.py286Standalone decode. compile_decode_kernels, run_decode_block, decode_attention_cpu.
llama32_1b_weights.py522HuggingFace safetensors loader. LlamaConfig, LayerWeights, LlamaWeights, load_weights, synthetic_weights, generate_rope_lut.
llama32_1b_cpu_helpers.py~90Small NumPy helpers shared by production + verify: rms_norm (LM-head GEMV final norm), attention_reference (prefill cpu_attn=True fallback), softmax (used by attention_reference). The file used to host a full F32 forward pass + standalone --verify CLI; both became redundant once the verify subsystem started comparing directly against HF transformers bf16.
verify/End-to-end verification subsystem. verify_runner.py orchestrates the top-k token gate (make verify) and the diagnosis lens (make diagnosis). See VERIFICATION.html.
Makefile112Convenience targets: compile, run, profile, chat, verify, diagnosis, clean.
+ +

Shared infrastructure kernel_builder/

+ + + + + + + + +
FileLinesPurpose
cache.py453The KernelCache class. Manages compile, cache, load, run, and BO reuse for all kernels. See B7.
stitching.py206Text-based MLIR stitching utilities for assembling multi-launch ELFs. See B5.
gemm_builder.py137Wraps the upstream matrix_multiplication/bf16/run.py:build_module + applies an additional MLIR transform IR script for prefill GEMMs. See B2.2.
external_kernels.py180Compiles all C++ .o kernel files via Peano (rope, silu_and_mul, mv, mv_k8192, attn).
backend_presets.py65All *_BACKEND kwarg dicts (RGR_BACKEND, OGF_BACKEND, etc.) — XRTBackend init params per kernel.
rope_halfsplit.cc~100Custom RoPE C++ kernel matching HuggingFace's half-split convention.
+ +

Multi-launch builders multi_launch_builder/

+ + + + + + + +
FilePhaseLaunchesBuilds
rms_gemms_rope_multi.pyPrefill6RMSNorm + Q/K/V GEMM + RoPE Q + RoPE K (Part A2 ops 1-6)
o_ffn_multi.pyPrefill8O GEMM + Add + RMSNorm + Gate/Up GEMM + SiLU×mul + Down GEMM + Add (Part A2 ops 8-15)
rms_gemv_rope_multi.pyDecode6RMSNorm(1D) + Q/K/V GEMV + RoPE Q + RoPE K — single-token version
o_gemv_ffn_multi.pyDecode8GEMV variants of o_ffn — single-token version
lm_head_gemv_multi.pyBoth88-partition vocab GEMV (16384 outputs each)
+ +

Other directories

+ + + + + + +
PathPurpose
standalone_kernels/K1..K10/Individual chunk-level kernels for debug; not used by production runtime.
ffn_swiglu/silu_and_mul.ccCustom SwiGLU C++ kernel.
docs/Documentation: profile.md, explain.md, usage.md, issues.md.
ablation/The 4-cell ablation study — decode (top-level pilot + decode/ full per-token) and prefill (prefill/). Comprehensive walkthrough in ABLATION_STUDY.html.
+ +

How model concepts (Part A) map to NPU code (Part B)

+ + + + + + + + + + + + + +
Model conceptNPU realizationFile:Function
One transformer block (14 ops)3 NPU calls per layer (rms_gemms_rope + flash_attn + o_ffn)llama32_1b_prefill.py:run_transformer_block
14 ops within a blockStitched into 6+1+8 = 15 sub-launches across 3 ELFs (B5)The multi_launch_builder/*_multi.py files
Token embedding lookupnumpy fancy-indexing on hostllama32_1b_inference.py:373 (embed_table[token_ids])
Final RMSNormHost CPU (1 row only — only the prediction row matters)llama32_1b_inference.py:425-430
LM HeadNPU 8-partition GEMV (1 ELF, 8 launches in 1 xrt.run)multi_launch_builder/lm_head_gemv_multi.py
K cache write (prefill, with transpose)numpy slice assign on host (B4 layout mismatch #2)llama32_1b_inference.py:401
K cache write (decode)numpy slice assign on host inside run_decode_blockllama32_1b_decode.py
Decode attentionCPU (numpy) — single-query GQA against the cache slicellama32_1b_decode.py:96 decode_attention_cpu
Prefill attentionNPU FlashAttention causal GQA (its own ELF, see B5)flash_attention/kernel_fusion_based/attn_npu2_seqfirst.py
Decode GEMV pre-transposed weightsOne-time CPU pre-transpose at startup (B4 layout mismatch #1)llama32_1b_inference.py:171-197
+ + + +

Part C — Verification

+ +

The verification subsystem lives in its own subdirectory (verify/) and is documented end-to-end in VERIFICATION.html. This part is a one-page pointer; treat the companion doc as the source of truth.

+ +

What runs

+ +

Two entry points, both routed through the parent Makefile and both comparing against HuggingFace transformers in bf16 (same dtype as the NPU — fair fight):

+ + + + + +
TargetWhat it doesPass/fail?
make verify [MODEL=base|instruct]8 prompts × 32 greedy-decoded tokens. At each step both runners' chosen tokens must appear in the OTHER side's top-5 (k=5). Mirrors vLLM's check_logprobs_close. ~4 min.Yes. Exits 1 on any FAIL.
make diagnosis [MODEL=...] [PROMPT="..."]Single prompt, prefill only. Per-layer ffn_out cosine + max_abs (NPU vs HF bf16) for all 16 layers. ~3 min.Informational only. Read the table by hand to localize a regression flagged by verify.
+ +

How it stays in sync with production

+ +

The verify NPU runner (verify/runners/npu_runner.py) is a thin adapter — it imports and invokes the same prepare_runtime, run_npu_prefill, and run_npu_decode_step functions that make run calls. Any change to the production prefill/decode path is automatically tracked by make verify; there is no parallel maintenance.

+ +

Why discrete top-k inclusion (and not continuous correlation)

+ +

bf16 ULP noise routinely flips per-step top-1 between two mathematically equivalent implementations, so a corr > 0.99-style threshold either trips on noise or sits so loose that real regressions slip through. Discrete top-k inclusion is the escape: bf16 noise can flip top-1 but rarely displaces a token from the top-5, so the gate distinguishes "drift" from "implementation bug" cleanly. See VERIFICATION.html §3 for the full argument.

+ +

CI

+ +

The LIT test run_npu2_verify.lit runs make verify MODEL=instruct on the NPU2 self-hosted runner and FileCheck-asserts [verify] PASS. REQUIRES: ryzen_ai_npu2, peano, hf_token — local runs without an HF token skip cleanly.

+ +

Part D — Future work

+ +

A running list of optimizations and design changes that the current production codebase does NOT do, but that we have identified as worth pursuing — typically because they unlock a new capability (larger models, lower latency) or remove a known scalability bottleneck. Each entry captures the motivation, current behavior, proposed change, and rough impact estimate, so a future contributor can pick one up without re-deriving the context.

+ +

Format: impact tag (how much it matters), effort tag (rough engineering size), status tag (idea / scoped / in-progress). This section grows over time as new ideas emerge.

+ + +

D1. Zero-copy weight loading — eliminate CPU↔BO duplication

+ +
+

Make BO the single physical storage for weights (no second numpy copy)

+ Impact: HIGH (scaling to larger models) + Effort: MEDIUM-LARGE + Status: identified, not scoped + +

Why it matters

+ +

The current preload pipeline keeps two or three physical copies of each weight tensor in DDR (see B7 "Subtle point: aren't CPU and NPU sharing the same DDR?"):

+
    +
  • The original numpy array from HuggingFace safetensors (~14 MB for wq)
  • +
  • The transposed bf16 copy _wq_t created by the GEMV pre-transpose step (B4 layout mismatch #1, ~14 MB)
  • +
  • The XRT BO that the NPU actually reads (~14 MB)
  • +
+ +

For Llama-3.2-1B (~2.5 GB of bf16 weights), the per-layer BO trick (~1 GB resident) plus duplicated numpy/transposed copies puts total memory at ~5-6 GB. This is fine on a 16-32 GB host, but it does NOT scale:

+ + + + + + + +
ModelBF16 weightsEstimated total RAM with current scheme (rough)
Llama-3.2-1B (current)~2.5 GB~5-6 GB ✓ fits
Llama-3.2-3B~6.4 GB~13-15 GB (tight on 16 GB host)
Llama-3.1-8B~16 GB~32-40 GB (won't fit on most consumer NPU2 systems)
Llama-3.3-70B~140 GB— (impossible without zero-copy)
+ +

Memory will become the bottleneck once we move beyond 1-3 B-parameter models. Solving this is a prerequisite for larger model deployment, not a nice-to-have.

+ +

Current behavior (what we want to change)

+ +

From preload_prefill_weights via cache.load_and_run with static_input_indices:

+ +
# Three physical copies in DDR for each weight tensor:
+weights.layers[5].wq                      # 1) HuggingFace numpy, ~14 MB pageable
+lw._wq_t = np.ascontiguousarray(           # 2) transposed numpy, ~14 MB pageable
+    lw.wq.astype(bfloat16)
+        .reshape(emb_dim, emb_dim).T
+)
+memcpy(bo.map(), lw._wq_t)              # 3) XRT BO, ~14 MB pinned
+bo.sync(TO_DEVICE)
+ +

Proposed change

+ +

Use np.frombuffer(bo.map(), ...) to make the BO the only physical storage; numpy is just a view onto it:

+ +
# Allocate the destination BO first
+bo = xrt.bo(device, weight_size_bytes)
+
+# Construct a numpy view that points INTO the BO's pinned region
+weight_view = np.frombuffer(
+    bo.map(), dtype=bfloat16, count=weight_n_elements
+).reshape(out_dim, in_dim)
+
+# safetensors loader writes directly into the BO via the numpy view
+load_safetensors_layer_into(weight_view, layer_idx, "wq")
+bo.sync(TO_DEVICE)
+# NO memcpy. NO second copy. The BO IS the weight storage.
+ +

Engineering cost (why it hasn't been done yet)

+ +
    +
  1. safetensors loader needs a "load into existing buffer" API. Today the loader returns a fresh numpy array — caller can't supply the destination buffer. This requires either a custom safetensors reader (~200 LOC) or a pre-allocate-then-copy step that defeats the purpose.
  2. +
  3. Transpose problem. The B4 weight pre-transpose materializes a NEW array (.T.ascontiguousarray()). For zero-copy to work end-to-end, the transposed result must land directly in the destination BO too. Either: +
      +
    • Allocate two BOs per weight (original + transposed), let the transpose write into BO #2, then free BO #1 — but at this point you've used 2× BO memory transiently and have a refcount-management problem
    • +
    • Have the safetensors loader perform the transpose during load (read in transposed order from the file format) — requires understanding safetensors' chunk layout
    • +
    +
  4. +
  5. Verify subsystem dependency. verify/runners/npu_runner.py calls prepare_runtime + run_npu_prefill + run_npu_decode_step with the production LlamaWeights object — the same one this BO-aliasing scheme would mutate. If a weight tensor switches from a numpy array to a bf16 BO view mid-call, both verify (HF-bf16 reference, dtype-agnostic) and diagnosis (per-layer ffn_out cosine) need to keep producing the same numbers. Audit the Hf-comparison path before flipping the storage.
  6. +
  7. BO lifetime + GC. If a numpy view holds a reference to bo.map() but the bo Python object is GC'd, the view becomes a dangling pointer. Need explicit owner-tracking (e.g. attach the BO as an attribute of the numpy view, or maintain a parallel _bo_keepalive list).
  8. +
  9. Multi-consumer weights. weights.lm_head is sliced into 8 partitions for the LM Head GEMV. If the source is a BO view, all 8 partition views must coexist without anyone freeing the underlying BO.
  10. +
+ +

Estimated impact

+ + + + + + + + +
SavesAmount
One-time preload memcpy time~200-300 ms (currently amortized; not in critical path)
Pageable RAM (numpy original)~2.5 GB for 1B model, scales with model size
Pageable RAM (transposed copy)~1.3 GB extra (decode-side weights only — prefill GEMM uses original layout)
Total RAM saving for 1B~3.8 GB → roughly halves total memory footprint
UnlocksLlama-8B+ on consumer NPU2 hardware that today can't fit those models
+ +

Suggested approach when scoped

+ +
    +
  1. Start with a tiny PoC: pick ONE weight tensor (e.g., layer 0's wq), implement the BO-allocate-then-numpy-view path, confirm bit-exact outputs vs. current path on the verify gate.
  2. +
  3. Extend to all weights for ONE layer; profile real RAM footprint to confirm savings.
  4. +
  5. Solve the transpose problem (likely: load safetensors in transposed order rather than transpose after).
  6. +
  7. Roll out across all 16 layers; deprecate the numpy weight reference path; add a flag to fall back for verify.
  8. +
  9. Validate on 3B model as a stretch test before committing to 8B-class ambitions.
  10. +
+ +

Background discussion: the trade-off and the pinned-vs-pageable subtlety are documented in B7. The reason "shared DDR" doesn't make this problem go away on its own is also there.

+
+ + +

D2. Cross-ELF BO aliasing — eliminate inter-ELF host round-trips

+ +
+

Wire producer-output BOs directly to consumer-input BOs across separate xrt.run() calls

+ Impact: LOW-MEDIUM (~3% prefill, ~0% decode) + Effort: MEDIUM + Status: validated by ablation Cell C, not in production + +

Why it matters

+ +

As documented in B5 "Intra-ELF vs inter-ELF intermediate flow", production currently routes intermediates between separate ELFs (e.g. rms_gemms_ropeflash_attno_ffn) through the host: producer output is sync'd to host, then memcpy'd + sync'd back into the consumer's input BO. This adds up to ~640 MB host↔device round-trip per prefill pass — about 3% of the 1.13 s prefill wall time. Decode is unaffected (intermediates are KB-scale).

+ +

Multi-launch ELF stitching (B5 / Gap #2) eliminates this for sub-launches inside one ELF, but FlashAttention is un-mergeable into the surrounding kernel-groups (compiler pass complexity), so prefill stays as 3 separate ELFs per layer with host-broker round-trips between them. Cross-ELF BO aliasing is the technique that recovers that 3% without merging the ELFs.

+ +

Current behavior (what we want to change)

+ +

From cells/multi_layer.py / production prefill loop:

+ +
for L in range(16):
+    rg_out = run_rms_gemms_rope(cache, layer_in, layer_idx=L)
+    # rg_out["q_roped"] is a numpy view onto host RAM — sync(FROM_DEVICE) just happened
+
+    q_roped_2d = rg_out["q_roped"].reshape(seq, emb)         # free metadata reshape
+    k_roped_2d = rg_out["k_roped"].reshape(seq, kv)
+    v_2d = rg_out["v"].reshape(seq, kv)
+
+    fa_out = run_flash_attn(cache, q_roped_2d, k_roped_2d, v_2d, layer_idx=L)
+    # ↑ entering FA: memcpy host numpy → FA's BO + sync(TO_DEVICE)
+    #   Same data that just left rms_gemms_rope's output BO is now duplicated in FA's input BO
+ +

Proposed change — alias the BOs explicitly

+ +

Use the same _share_bo helper that ablation Cell C already validated:

+ +
# During preload, after both ELFs have allocated their BOs:
+_share_bo(cache,
+    f"rms_gemms_rope_L{L}", slot=11,        # producer's q_roped output BO
+    f"flash_attn_L{L}",       slot=0,         # consumer's Q input BO — now points at same DDR
+)
+_share_bo(cache, f"rms_gemms_rope_L{L}", 12, f"flash_attn_L{L}", 1)   # K
+_share_bo(cache, f"rms_gemms_rope_L{L}",  8, f"flash_attn_L{L}", 2)   # V
+_share_bo(cache, f"flash_attn_L{L}", 3, f"o_ffn_L{L}", 0)               # attn_out
+
+# During timed inference, mark these slots intermediate so KernelCache skips host I/O:
+fa_out = cache.load_and_run("flash_attn", FA_BACKEND, ...,
+    intermediate_indices={0, 1, 2, 3},          # Q, K, V (in), attn_out (out)
+    # NO output_indices for attn_out — it stays on device for o_ffn
+)
+ +

How much can actually be saved

+ +

Not all inter-ELF transfers can be 100% eliminated, because the host still needs SOME of them for non-NPU work:

+ + + + + + + + +
TransferCan fully alias?Reason
Q (rms_gemms_rope → FA)✅ YesHost never touches Q during prefill
K (rms_gemms_rope → FA)⚠️ PartialFA reads it, AND host needs to sync(FROM_DEVICE) + transpose to write KV cache (B4 mismatch #2). Save the host→FA write only
V (rms_gemms_rope → FA)⚠️ PartialSame as K
attn_out (FA → o_ffn)✅ YesHost never touches attn_out
o_ffn output → next layer's rms_gemms_rope's x_in✅ YesPure layer-to-layer activation pass
+ +

Best-case saving: drop ~640 MB / pass to ~150 MB / pass (KV cache extraction still needs the device→host read). Wall-time saving: from ~3% to ~0.7% — recovering ~25 ms of the prefill.

+ +

Engineering cost (why it hasn't been done yet)

+ +
    +
  1. Manual BO graph maintenance. Every cross-ELF data flow requires an explicit _share_bo wiring call during preload. For 16 layers × 4-5 cross-ELF edges, that's ~70 wiring lines that must stay synchronized with the kernel-group load_and_run argument layouts. If a layout changes, every aliasing line has to be audited.
  2. +
  3. Shape mismatch between producer and consumer. rms_gemms_rope emits 1D flat arrays (q_roped[seq*emb]); FA expects 2D (seq, emb). Today the host does the metadata-only reshape between them. With aliasing the host is no longer in the loop — the shape conversion has to happen on the MLIR side via memref.expand_shape at the FA entry, which means modifying FA's kernel signature or wrapping its launch.
  4. +
  5. KV cache write coordination. K and V are needed by both the FA (consumer) and the host (KV cache writer). Aliasing means both read from the same BO. The host's sync(FROM_DEVICE) must happen at the right moment — after the producer has finished writing but before/during FA reading. Currently the host-broker pattern enforces this naturally; with aliasing it needs explicit ordering.
  6. +
  7. FA's internal BO reuse. FlashAttention is un-mergeable partly because of how it uses air.channels and many internal sub-buffers. Aliasing its input BOs needs to verify that FA doesn't internally reuse those slots in a way that would corrupt the producer's data mid-execution.
  8. +
+ +

Estimated impact

+ + + + + + + +
SavesAmount
Inter-ELF host↔device round-trip per prefill pass~640 MB → ~150 MB (factor 4× reduction)
Wall time per prefill pass~25 ms (~2.3% of 1.13 s)
Wall time per decode token< 1 ms (negligible — intermediates are KB-scale in decode)
Doesn't change anything forDecode performance, model size scaling, code complexity tradeoffs
+ +

Suggested approach when scoped

+ +
    +
  1. Start with the easiest edge: alias attn_out (FA → o_ffn). It has no host consumer, so it's a clean win.
  2. +
  3. Validate output vs. the production path on make verify (top-k token gate) and inspect make diagnosis for unexpected per-layer drift.
  4. +
  5. Profile to confirm the predicted ~5-10 ms / pass saving is real.
  6. +
  7. Add Q aliasing next (also no host consumer).
  8. +
  9. Tackle K/V partial aliasing last — needs the host-readout coordination.
  10. +
  11. Consider whether the engineering cost is worth ~25 ms / pass at this point. If decode-side or memory-side optimizations (D1) become the priority, this can be deferred indefinitely.
  12. +
+ +

Background: ablation Cell C already implements this pattern WITHIN one kernel-group (between separate xrt.run()s of the un-merged baseline). The same _share_bo mechanism would extend to ACROSS kernel-groups in production.

+
+ + +

D3. CI: wire up HF_TOKEN so make verify actually runs in CI

+ +
+

The verify gate is shipped but not enforced by CI yet

+ Impact: MEDIUM (CI cannot catch verify regressions today) + Effort: SMALL + Status: identified, not done + +

Why it matters

+ +

The whole point of refactoring NpuRunner into a thin adapter over the production prefill/decode functions (VERIFICATION.html) is that any change to production code is automatically tracked by make verify — no parallel maintenance. But that guarantee only pays off if CI actually runs make verify on every PR. Today it does not.

+ +

Current behavior

+ +
    +
  • run_npu2_verify.lit exists and declares REQUIRES: ryzen_ai_npu2, peano, hf_token.
  • +
  • programming_examples/lit.cfg.py sets the hf_token lit feature only when the HF_TOKEN env var is present (so local runs without a token skip cleanly instead of failing).
  • +
  • .github/workflows/buildAndTestRyzenAI.yml runs ninja check-programming-examples-peano but does NOT inject HF_TOKEN into the job's env. As a result, lit doesn't enable the hf_token feature, and run_npu2_verify.lit is skipped on every CI run — no failure, but no actual verify either.
  • +
+ +

Proposed change

+ +
    +
  1. In .github/workflows/buildAndTestRyzenAI.yml, inject HF_TOKEN at the job (or just the lit-test step) level: +
    env:
    +  HF_TOKEN: ${{ secrets.HF_TOKEN }}
    +
  2. +
  3. In the GitHub repo settings (Settings → Secrets and variables → Actions), add a repository secret named HF_TOKEN with a read token for meta-llama/Llama-3.2-1B-Instruct (and the base model if running the MODEL=base variant in CI). Required on the fork that runs CI; if upstream wants the verify gate too, the same secret needs to be configured there.
  4. +
  5. (Optional) Cache ~/.cache/huggingface/ in the workflow to avoid re-downloading the 2.5 GB checkpoint on every run. Self-hosted runners typically persist this directory naturally, so this is only needed for ephemeral runners.
  6. +
+ +

What this buys

+ +

Every PR runs the 8-prompt × 32-token top-k inclusion gate against HF transformers bf16, end to end through the production prefill + decode kernels. ~4 min added to the existing Ryzen AI CI step. Without it, any regression in run_npu_prefill, run_npu_decode_step, the multi-launch kernel builders, or the external kernels (rope.o, silu_and_mul.o, attn_npu2.o, mv.o, mv_k8192.o) can land if its symptom is “tokens drift outside top-5” rather than a structural breakage caught by other tests.

+ +

Risk

+ +

Tiny. Adding the env var is one line; missing the secret in the env just keeps the current skip-behavior (the test fails cleanly with “REQUIRES: hf_token” not satisfied, but does not break the rest of CI).

+
+ + + +

Part E — Reference

+ +

E1. Glossary — terms defined in one place

+ +
+ +
Buffer Object (BO)
+
An XRT abstraction for a chunk of NPU-accessible memory (in DDR — the same physical RAM the host sees, but with NPU access permissions). Created by xrt.bo(device, size_bytes). Has .map() (returns a host pointer for memcpy) and .sync(direction) (cache flush + barrier). One BO per kernel argument. "Allocating a BO" is cheap; "syncing a BO" is what costs time.
+ +
Per-layer weight BO
+
A BO that holds the weight tensor for a SPECIFIC layer of the transformer. The trick: KernelCache caches BOs keyed by bo_key. When preload_prefill_weights calls load_and_run(..., bo_key="rms_gemms_rope_L5") with layer 5's wq tensor in slot 3, KernelCache allocates a fresh BO list for that key and writes the weights. Later, when inference does the same call with the same bo_key, KernelCache finds the cached BOs (already on device with the right weights), and static_input_indices={3, ...} tells it to skip writing slot 3 from host. 16 layers × 2 kernels × ~6 weight slots ≈ ~200 cached weight BOs holding ~1 GB of weights resident on device.
+ +
Static input indices (static_input_indices)
+
The set of arg slot indices that hold weights/LUTs (data that doesn't change between calls). On any call after the first for a given bo_key, these slots are skipped by the host write loop in load_and_run. The BO already has the right data from the preload call.
+ +
Intermediate indices (intermediate_indices)
+
The set of arg slot indices that hold buffers the kernel will OVERWRITE — it doesn't matter what's in them on entry. The host doesn't need to initialize them; load_and_run skips writing zeros to these slots (saves a memcpy + sync). For a multi-launch ELF, intermediate slots include both internal handoff buffers (like normed) and the final output (until the host reads it back via output_indices).
+ +
Shared intermediate BO
+
NOT a feature of production code (production uses multi-launch merging instead). This is an ablation-study concept: if you have two SEPARATE xrt.run() calls where call N's output is call N+1's input, you can manually alias call N's output BO into call N+1's input BO (via the _share_bo helper), so the data goes from device to device without a host round-trip. In the ablation it isolates "BO sharing" from "ELF merging" as separate optimizations.
+ +
Multi-launch ELF
+
One .elf binary that contains multiple air.launch operations stitched into a single func.func. Invoked by ONE xrt.run() call. The launches execute sequentially within the single XRT submission, with intermediates flowing through DDR (NPU DMA reads/writes) without CPU involvement. Saves XRT dispatch overhead and host orchestration cost.
+ +
Sub-launch
+
One air.launch operation. The 6 sub-launches in rms_gemms_rope.elf are the 6 logical kernels (RMSNorm, Q GEMM, K GEMM, V GEMM, RoPE Q, RoPE K) — each was originally a separate air.launch in its own MLIR module before stitching.
+ +
Herd
+
An AIR dialect concept: a 2D array of NPU compute tiles all running the same kernel code in parallel. E.g., air.herd @h tile(%tx, %ty) in (%sx=8, %sy=4) means an 8×4 grid of tiles. Inside an air.launch, each herd is mapped to physical AIE tiles by the air-place-herds compiler pass.
+ +
Segment
+
An AIR dialect concept above the herd: air.segment represents a partition of the NPU array. The wrapping air.launch { air.segment { air.herd { ... } } } is the canonical AIR program structure. Required so that airrt-to-npu emits airrt.segment_load ops.
+ +
aircc / aiecc
+
Two MLIR-AIR compiler drivers. aircc runs the AIR-dialect passes (dependency analysis, broadcast detection, herd placement, AIR→AIE lowering). aiecc runs the AIE-dialect passes (vectorization, routing, generates per-tile ELFs, packages into the final .elf + .insts.bin).
+ +
Peano
+
The AMD fork of LLVM that targets the AIE2P ISA. Used to compile C++ kernels (rope.cc, silu_and_mul.cc, mv.cc) into per-tile object files that get linked into the AIE ELF.
+ +
RoPE LUT
+
Pre-computed cosine/sine table for Rotary Position Embedding. generate_rope_lut in llama32_1b_weights.py builds an array of shape (max_seq, head_dim) = (2048, 64) in bf16. The first half is cos, second half is sin (concatenated, not interleaved — matches the half-split RoPE convention).
+ +
GQA (Grouped Query Attention)
+
Llama-3.2-1B has 32 Q heads but only 8 KV heads. Each KV head is shared by 4 Q heads. Reduces KV cache size 4× without much quality loss. Implemented in both NPU FA and CPU attention by indexing kv_h = h // group_size.
+ +
SwiGLU
+
The FFN activation used by Llama: SwiGLU(gate, up) = SiLU(gate) * up elementwise. Two GEMMs (gate, up) feed it; one GEMM (down) follows. Compared to GELU, requires 1 extra GEMM but learns better.
+ +
RMSNorm
+
Root-Mean-Square layer normalization: RMSNorm(x, w) = x · rsqrt(mean(x²) + ε) · w. Like LayerNorm but without the mean-subtraction and without a bias parameter. Cheaper and works equally well for transformers.
+ +
KV cache
+
Per-layer cache of K and V tensors at every token position seen so far. During decode, attention reads the entire cache (positions 0..current_pos) but only computes one new K and V (for the new token). Without it, decode would be O(N) per token instead of O(1). See Part A4.
+ +
Prefill / Decode
+
Two operating modes of LLM inference. Prefill: process the whole prompt at once (seq=N), populate KV cache. Decode: process one new token (seq=1), append to KV cache, get next token. Repeated decode generates text. See Part A3.
+ +
Padding (in this implementation)
+
NPU kernels are compiled for fixed shapes. Llama-1B's prefill kernels expect seq=2048. Shorter prompts get padded with EOS tokens up to 2048; the prefill processes all 2048 positions but only the logits at pred_pos = prompt_len - 1 are used. See Part A5.
+ +
+ +

E2. Reading guide — where to start for specific questions

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
If you want to understand…Read these in this order
The model itself (math, no NPU)1. Part A2 of this guide
2. Optionally: the original Llama paper for context
The whole pipeline end-to-end1. Makefile (entry points)
2. llama32_1b_inference.py — start with main() at the bottom, then build_session, run_once, generate, run_npu_prefill
3. llama32_1b_decode.py:run_decode_block
How weights are loaded and pre-staged1. llama32_1b_weights.pyload_weights()
2. llama32_1b_inference.py:prepare_runtime (line 129)
3. llama32_1b_inference.py:_preload_decode_weights (line 219)
4. llama32_1b_prefill.py:preload_prefill_weights
How a single ELF gets compiled1. multi_launch_builder/rms_gemms_rope_multi.py:build_rms_gemms_rope_module (line 193) — the highest-level builder
2. kernel_builder/stitching.py — text manipulation helpers
3. kernel_builder/cache.py:compile_and_cache (line 251)
4. kernel_builder/external_kernels.py — C++ .o compilation
How an ELF gets invoked at runtime1. kernel_builder/cache.py:load_and_run (line 294) — the central dispatch function
2. Any caller in llama32_1b_inference.py or llama32_1b_decode.py
3. kernel_builder/backend_presets.py — the backend kwargs dicts
How multi-launch merging works1. kernel_builder/stitching.py in full
2. multi_launch_builder/rms_gemms_rope_multi.py lines 466-481 (the stitch loop)
3. docs/explain.md for the design rationale
Why decode uses CPU attention1. llama32_1b_decode.py:decode_attention_cpu (line 96)
2. docs/issues.md for the documented limitation
3. docs/profile.md "Decode Breakdown" section
Performance breakdown / where time goes1. docs/profile.md top-to-bottom — has all the numbers
2. kernel_builder/cache.py:Profiler class (line 54)
3. Run make profile to see live numbers
How to add a new kernel-group1. Look at any multi_launch_builder/*_multi.py as a template
2. Need a build_module entry point + sub-builder calls + a stitch loop
3. Add a backend preset to kernel_builder/backend_presets.py
4. Add compile + load_and_run wiring in llama32_1b_inference.py
+ +

Quick-reference: which file does what when you grep

+ + + + + + + + + + + + + +
If you grep for…Meaningful hits in…
load_and_runcache.py (def), llama32_1b_inference.py + llama32_1b_decode.py + llama32_1b_prefill.py (callers)
bo_keycache.py (cache impl), and every preload/run call in inference scripts
static_input_indicesSame as bo_key + load_and_run
compile_and_cachecache.py (def), llama32_1b_prefill.py:compile_all_kernels, llama32_1b_decode.py:compile_decode_kernels
build_moduleEach multi_launch_builder/*_multi.py file's main entry point
_wrap_ir_in_launchstitching.py (def), used by builders that wrap bare herds
RGR_BACKEND / OGF_BACKEND / LM_GEMV_BACKENDbackend_presets.py (def), and at every call site
output_indicesThe contract document for what the caller wants back from each kernel
k_cache / v_cachellama32_1b_inference.py (allocation + prefill writes) and llama32_1b_decode.py:decode_attention_cpu (reads + appends)
pred_posllama32_1b_inference.py:run_npu_prefill — the "find last real prompt token" logic from Part A5
+ +
+ Llama-3.2-1B NPU2 production implementation guide. Last updated 2026-05.
+ Source: programming_examples/llama32_1b/ on branch llama-3.2-1B-devel.
+ Companion docs: ABLATION_STUDY.html, profile.md, explain.md, ARCHITECTURE.md.
+ Plus the spec / plan documents under programming_examples/llama32_1b/ablation/docs/. +
+ + + diff --git a/programming_examples/llama32_1b/docs/PROFILE.html b/programming_examples/llama32_1b/docs/PROFILE.html new file mode 100644 index 000000000..716562114 --- /dev/null +++ b/programming_examples/llama32_1b/docs/PROFILE.html @@ -0,0 +1,576 @@ + + + + +Llama-3.2-1B Performance Profile (NPU2) + + + + + + + +

Llama-3.2-1B Performance Profile (NPU2)

+

Per-step wall-time attribution of the production prefill + decode pipeline, end-to-end. Diagrams mirror the dataflow in IMPLEMENTATION_GUIDE.html Part B1; numbers are reproduced from a single make profile run on NPU2 (AMD Strix), seq_len=2048, MODEL=instruct.

+ + +

What make profile reports

+ +

make profile runs the same code path as make run — the production prefill + decode functions, end to end, real HuggingFace weights — and just enables the otherwise-disabled Profiler instance that cache.load_and_run already records into. There is no profile-only code path; any change to the production functions is automatically reflected in the profile.

+ +

The report (printed at the end of the run) opens with an architecture-aware dataflow summary (matches this page’s SVG order) and then dumps generic detail tables per phase (prefill / decode):

+ + + + + + + + + + +
SectionWhat it tells you
END-TO-END DATAFLOW (at the top)Architecture-aware walkthrough: tokenize → eos_pad → embed → 16×(rms_gemms_rope + flash_attn + o_ffn + kv_cache_extract) → final_norm → lm_head_gemv. Each row tagged CPU/NPU/— with measured ms. Same ordering as the SVGs in Part A / Part B below. Also prints the one-time Preprocessing (prepare_runtime) wall as a reminder.
Wall-Time AttributionHow the total wall budget splits across NPU XRT calls, CPU host ops, and the layer-loop envelope (sanity check; remainder is python scheduling).
Per-Layer ExecutionOne row per layer for prefill; aggregated avg/min/max across tokens for decode.
NPU XRT Call BreakdownEach multi-launch ELF’s wall time per invocation, plus call count. The granularity is one XRT run = one merged ELF (sub-launches inside the ELF stay opaque, since that’s how production dispatches them).
CPU Op BreakdownEach tracked CPU host operation (tokenize, eos_pad, embed lookup, KV-cache extract, final RMSNorm, decode CPU attention).
Fine-Grained NPU BreakdownEach XRT call further split into BO Write / NPU Run / BO Read (concept explained in Part C).
Per-Token Wall Trend (decode only)Per-token layer-loop wall for token 1 / middle / last + min/max/avg + first→last drift. Lets you see whether per-token latency grows with KV-cache length (decode CPU attention is O(current_pos)). With a 2048-token prompt and 30 decode tokens the drift is typically <1%.
+ +

Headline numbers

+ +

Snapshot from the report (single run, instruct model, 30 decode tokens):

+ + + + + + +
MetricWallNotes
TTFT (time-to-first-token, prefill end-to-end)~1.28 stokenize + EOS-pad + embed + 16×layer + final RMSNorm + LM head. Matches the vLLM / TGI / TRT-LLM TTFT metric (user-facing latency from request submit to first output token). 95% NPU-bound. Tokenize varies by prompt length; ~10 ms typical.
TPOT (per output token, steady-state decode)~92 ms (10.8 tok/s)16 layers × 4.95 ms each + 13.6 ms LM head + ~0.1 ms host wrappers. Slope vs token index is <1% over 30 tokens (KV cache grows by ~1.5% on a 2048-token prompt).
Preprocessing (one-time, prepare_runtime)~7.6 sCompile external kernels + pre-load weights into per-layer BOs. Happens once per process and is NOT included in TTFT.
+ +
+ CPU host op + NPU XRT call (multi-launch ELF) + FlashAttention (separate ELF, see B5) +
+ + +

Part A — Prefill (TTFT ~1.28 s)

+ +

One inference’s prefill phase: prompt → first generated token. Each box shows the step, where it runs, and the measured wall time. The 16 layers are identical; one iteration is shown in the “decoder block” container.

+ + + + + + + + + + + + Tokenize + EOS-pad to seq_len + CPU; HF chat template + tokenizer.encode + pad + ~10 ms tokenize + ~0 ms pad + + + + + + + Token embedding lookup + CPU; numpy gather + bf16 cast + ~5.8 ms + + + x: [2048, 2048] bf16 + + + + + Decoder block × L = 16 (one iteration shown; ~77.9 ms per layer; total ~1247 ms) + + + + + + rms_gemms_rope.elf — 1 xrt.run, 6 stitched launches + RMSNorm + Q/K/V GEMM + RoPE Q + RoPE K + 7.3 ms (BO write 0.5 / NPU 6.5 / BO read 0.1) + + + q_roped, k_roped, v + + + + + flash_attn.elf — 1 xrt.run, separate ELF + 1 launch; un-mergeable (see B5) + 21.6 ms (BO write 1.3 / NPU 20.1 / BO read 0.1) + + + attn_out [2048, 2048] + + + + + o_ffn.elf — 1 xrt.run, 8 stitched launches + O + Add + RMSNorm + Gate/Up + SwiGLU + Down + Add + 41.0 ms (BO write 1.0 / NPU 39.8 / BO read 0.1) + + + x_next (= next layer's input) + + + + + KV cache extract & write + CPU; reshape + transpose + slice-assign of k_roped, v + 1.1 ms per layer (×16 = 17.6 ms) + + + (loop back; 16 layers total) + + + + Per layer total: 7.3 + 21.6 + 41.0 + 1.1 = 71.0 ms (kernel+CPU) + + + Layer-loop wall: 77.9 ms → ~7 ms python/numpy scheduling overhead per layer + + + 16 layers × 77.9 ms = 1247 ms + + + + + x: [2048, 2048] after 16 layers + + + + + Final RMSNorm @ row pred_pos + CPU; only the 1 row needed for next-token argmax + 3.1 ms + + + [1, 2048] normed + + + + + lm_head_gemv.elf — 1 xrt.run, 8 partitions + Reuses decode-side ELF for the single-row projection (see A7) + 13.6 ms (BO write 0 / NPU 13.5 / BO read 0) + + + logits [1, 128256] → argmax + + + + + First generated token + + + + + TTFT (time-to-first-token): ~1280 ms + + + = 10 (tokenize) + ~0 (pad) + 5.8 (embed) + 1247 (16 layers) + 3.1 (norm) + 13.6 (LM head) ≈ 1280 ms + + + NPU XRT 1119 ms (87%) · CPU host 37 ms (3%) · python sched ~125 ms (10%, mostly inside layers) + + + + +

Prefill: per-kernel and fine-grained tables

+ + + + + + + + +
NPU XRT calls (16 layer-invocations of each, plus 1 LM head)
ELFLaunchesavg / callBO WriteNPU RunBO ReadBO MB written
rms_gemms_rope6 stitched7.3 ms0.5 ms6.5 ms0.1 ms8.0 MB
flash_attn (separate ELF)121.6 ms1.3 ms20.1 ms0.1 ms20.0 MB
o_ffn8 stitched41.0 ms1.0 ms39.8 ms0.1 ms16.0 MB
lm_head_gemv (prefill end)8 stitched13.6 ms~0 ms13.5 ms~0 ms~0 MB
+ + + + + + + + + + +
CPU host ops (prefill side)
OpCountavgTotal
tokenize1~10 ms~10 ms
eos_pad1~0 ms~0 ms
embed_lookup15.8 ms5.8 ms
kv_cache_extract161.1 ms17.6 ms
final_rms_norm13.1 ms3.1 ms
Total CPU20~37 ms
+ +

Wall-time attribution check: NPU XRT 1119 ms (16 layer-invocations × 3 kernels + 1 LM head = 49 calls) + CPU host ~37 ms = ~1156 ms accounted, vs. TTFT ~1280 ms → ~125 ms unattributed python/numpy scheduling, mostly inside the layer loop.

+ + +

Part B — Decode (per token ~92 ms)

+ +

Per-token decode step: takes the last produced token, returns the next. Diagram and numbers cover one token; the loop repeats until EOT. Each kernel reflects an avg over 30 decode tokens, 16 layers.

+ + + + + + + + + + + + Embed lookup (next-token id → row) + CPU; weights.embed_table[id].astype(bf16) + ~0 ms (single row gather) + + + x: [2048] bf16 + + + + + Decoder block × L = 16 (one iteration shown; ~5.0 ms per layer; total ~79 ms) + + + + + + rms_gemv_rope.elf — 1 xrt.run, 6 stitched launches + RMSNorm + Q/K/V GEMV + RoPE Q + RoPE K (single token) + 0.9 ms (BO write 0 / NPU 0.8 / BO read 0) + + + q_roped [2048]; k_roped, v [512] each + + + + + decode_attention_cpu + CPU single-query attention against KV cache (head_dim=64; FA NPU has too much overhead at single-query) + 0.3 ms per layer + + + attn_out [2048] + + + + + o_gemv_ffn.elf — 1 xrt.run, 8 stitched launches + O + Add + RMSNorm + Gate/Up + SwiGLU + Down + Add + 3.7 ms (BO write 0 / NPU 3.6 / BO read 0) + + + x_next (= next layer's input) + + + + append k,v at pos + + + + Per layer total: 0.9 + 0.3 + 3.7 = 4.9 ms (kernel+CPU) + + + Layer-loop wall: 4.95 ms → ~0.05 ms python/numpy overhead per layer + + + 16 layers × 4.95 ms = 79.2 ms + + + x: [2048] after 16 layers + + + + + Final RMSNorm + CPU; single row, F32 internal + 0.07 ms + + + [1, 2048] normed + + + + + lm_head_gemv.elf — 1 xrt.run, 8 partitions + 8-partition GEMV stitched in 1 ELF + 13.6 ms (NPU 13.5 dominates) + + + logits [1, 128256] → argmax + + + + + next token id + + + + + Total per-token wall: ~92 ms + + + = ~0 (embed) + 79.2 (16 layers) + 0.07 (norm) + 13.6 (LM head) ≈ 93 ms + + + NPU XRT ~85 ms (92%) · CPU host ~5 ms (5%) · LM head dominates the per-token bill at 15% + + + + +

Decode: per-kernel and fine-grained tables

+ + + + + + + +
NPU XRT calls (avg over 30 decode tokens × 16 layers)
ELFLaunchesavg / callBO WriteNPU RunBO Read
rms_gemv_rope6 stitched0.9 ms0.02 ms0.83 ms0.01 ms
o_gemv_ffn8 stitched3.7 ms0.02 ms3.64 ms0.01 ms
lm_head_gemv8 stitched13.6 ms0.01 ms13.50 ms0.03 ms
+ + + + + + + + +
CPU host ops (decode side)
OpCount / tokenavgTotal / token
decode_attention_cpu160.28 ms4.5 ms
embed_lookup1~0 ms~0 ms
final_rms_norm10.07 ms0.07 ms
Total CPU / token18~4.6 ms
+ +

Wall-time check: NPU XRT per token = 16 × (0.9 + 3.7) + 13.6 = 87.2 ms · CPU = 4.6 ms · sum 91.8 ms ≈ observed 92 ms wall. Decode is overwhelmingly NPU-bound; the LM head GEMV alone is ~15% of the per-token cost.

+ +

Observation: across decode, BO Write is <1% — this is the payoff for pre-loading all weights into per-layer BOs (and marking them static_input_indices) during prepare_runtime. Without that, each layer would re-write its 116 MB of weights per token.

+ + +

Part C — BO Write / NPU Run / BO Read explained

+ +

Each cache.load_and_run("kernel", backend, arg0, ..., argN) invocation is split into three timed segments:

+ +

1. BO Write — t_write_ms

+ +

For each input/intermediate argument that needs new bytes, the host does memcpy(numpy_data → BO.map()). Args marked static_input_indices (e.g. layer weights) skip this step on every call after prepare_runtime, so steady-state t_write_ms mainly reflects the dynamic inputs that change call-to-call (the input activation, RoPE LUT row, KV-cache slice, …).

+ +

What this measures in practice: host-to-DDR memcpy bandwidth for the dynamic inputs only. If you see this rise, either an argument lost its static_input_indices mark, or a normally-small dynamic input grew (e.g. a bigger seq_len).

+ +

2. NPU Run — t_kernel_ms

+ +

Wall time of xrt.run.start() + xrt.run.wait(). This is the NPU actually executing the multi-launch ELF: DDR → L2/L1 DMAs, AIE-tile compute, and L1/L2 → DDR DMAs of outputs. Host does nothing here except spin-wait the completion signal.

+ +

What this measures: real NPU hardware execution time for the ELF. All the multi-launch’s stitched sub-launches (e.g. RMSNorm + Q + K + V + RoPE_Q + RoPE_K inside rms_gemms_rope.elf) run sequentially on-device and are not separately resolved here — that’s by design, because production never dispatches them separately.

+ +

3. BO Read — t_read_ms

+ +

For each output argument, the host constructs a numpy view over the BO’s mapped memory using np.frombuffer(BO.map(), …). This is zero-copy — no memcpy — and consistently <0.1 ms. If t_read_ms ever climbs into the ms range, that signals an accidental copy was introduced (e.g. an .astype() on a large output).

+ +

How they sum

+ + + + + +
PhaseBO WriteNPU RunBO Read
Prefill (one full pass)~46 ms (4%)~1062 ms (95%)~5 ms (0%)
Decode (per token)~0.6 ms (1%)~86 ms (98%)~0.3 ms (0%)
+ +

Both phases are dominated by NPU Run — the host’s job is mostly to feed the right BOs and wait. Decode is even closer to pure-NPU because the per-token dynamic inputs are tiny (a single activation row vs. an entire sequence’s worth).

+ + +

How to reproduce the numbers

+ +
cd programming_examples/llama32_1b
+
+# One-time kernel compilation (~3-4 min, cached)
+make compile
+
+# Full profiling report (single run, instruct model)
+make profile N_TOKENS=30 PROMPT="Explain photosynthesis in detail."
+
+# Or with the base checkpoint
+make profile MODEL=base N_TOKENS=30 PROMPT="Once upon a time"
+
+ +

The report is printed to stdout at the end of the run. To save a copy:

+ +
make profile 2>&1 | tee profile_$(date +%Y%m%d-%H%M%S).log
+ +

Numbers will jitter ±3-5% between runs (NPU power state, OS scheduling, etc); the breakdown structure is stable. make verify is the orthogonal gate that ensures the production code path producing these numbers is still numerically correct.

+ +
+ +

+ Companion: profile.md (textual perf summary, optimization history, vs IRON comparison) · + IMPLEMENTATION_GUIDE.html B1 (same dataflow, no timing — shows just the structural picture) · + ABLATION_STUDY.html (4-cell controlled measurement of how each dispatch optimization contributes to these numbers). +

+ + + + + diff --git a/programming_examples/llama32_1b/docs/VERIFICATION.html b/programming_examples/llama32_1b/docs/VERIFICATION.html new file mode 100644 index 000000000..7c892f8b7 --- /dev/null +++ b/programming_examples/llama32_1b/docs/VERIFICATION.html @@ -0,0 +1,446 @@ + + + + +Llama-3.2-1B Verification Subsystem + + + + + + + +

Llama-3.2-1B Verification Subsystem

+

Two ways to look at the production NPU2 inference pipeline, both comparing against HuggingFace transformers in bf16. Companion to IMPLEMENTATION_GUIDE.html Part C.

+ +

Two lenses, one bf16 reference

+ +
+make verify [MODEL=instruct|base] — the industry-standard correctness gate. 8 prompts × 32 greedy tokens, top-5 set inclusion vs HuggingFace transformers bf16 on the NPU end-to-end production path (NPU FlashAttention on, no CPU attention fallback). Lite-mode runners — no inside probing. ~4 minutes / run. Default MODEL=instruct matches what production stacks deploy. +
+ +
+make diagnosis [MODEL=...] — the inside-probing lens. Single prompt's prefill, per-layer ffn_out cosine + max_abs (NPU vs HF bf16) for all 16 layers. Same end-to-end NPU production path as verify (NPU FlashAttention on). Informational only — diagnosis never fails the run. The verify gate is the correctness signal; this table is what you read by hand when verify flags an issue and you need to localize. ~2 minutes / run. +
+ +
+Why two lenses? verify answers "would this model deploy" using the exact criterion industry uses to qualify a BF16 LLM for production — discrete top-k judgment that is robust to bf16 ULP noise. diagnosis gives localization: a continuous-cosine table per layer that tells you where the NPU implementation drifts most from HF. The verify gate gates; the diagnosis lens informs. +
+ +
+Latest results (2026-05-15): +
    +
  • make verify MODEL=instruct: 8/8 PASS, ~3m41s
  • +
  • make verify MODEL=base: 8/8 PASS, ~3m39s
  • +
  • make diagnosis MODEL=instruct (NPU FA on): cos_p5 in [0.926, 0.993], U-shape with single L1-L2 dip and L10 peak.
  • +
  • make diagnosis MODEL=base (NPU FA on): cos_p5 in [0.929, 0.992], double-dip shape (L1-L3 and L12-L14). Same-checkpoint dependence on prompt + fine-tune is what diagnosis surfaces; both pass verify regardless. See Part B.
  • +
+
+ + +

A. make verify — the correctness gate

+ +

The check (mirrors vLLM's check_logprobs_close)

+ +
    +
  1. Each runner (NPU + HF bf16) greedy-decodes 32 tokens for one prompt, capturing the chosen token + top-5 token IDs at every step.
  2. +
  3. Walk both sequences in lockstep. Same chosen token → continue. Different chosen tokens → require both to appear in the OTHER side's top-5; otherwise FAIL. Stop walking after the first divergence.
  4. +
  5. All 8 prompts must pass; any FAIL exits with code 1.
  6. +
+ +

NPU runs the full production path (GEMV + RMSNorm + RoPE + FlashAttention + LM-head GEMV). Discrete top-k inclusion is robust to bf16 ULP noise: noise routinely flips per-step top-1 between mathematically equivalent implementations but rarely displaces a token from the top-5.

+ +

Two prompt sets, matched to checkpoint behavior

+ + + + + + + + + + + +
#Base (verify/prompts/base.txt)Instruct (verify/prompts/instruct.txt)
0GPU stands forIntroduce me what is GPU
1The capital of France isBriefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.
2Artificial intelligence is a branch of computer science thatCompare and contrast artificial intelligence with human intelligence in terms of processing information.
3A neural network consists ofDescribe the basic components of a neural network and how it can be trained.
4Once upon a time, there was a robot who dreamed aboutWrite a short story about a robot that dreams for the first time.
5The COVID-19 pandemic, which began in late 2019,Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.
6The Mona Lisa was painted byExplain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.
7The French translation of "The early bird catches the worm" isTranslate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'
+ +

Topics deliberately mirror each other so base-vs-instruct comparisons read naturally row-by-row. Base prompts are intentionally incomplete sentences (the base model continues raw text rather than answering instructions). Instruct prompts are imperative requests (7 verbatim from vllm/tests/prompts/example.txt + 1 swapped for project relevance).

+ +

Per-prompt results (NPU vs HF bf16, k=5)

+ +

For each prompt we display the first divergence step (0-based; step 0 is the prefill prediction, step 1 is the first decode token); each side's chosen token at that step (decoded text, quoted so leading whitespace stays visible) plus its 1-based rank in the OTHER runner's top-5; and the agreed prefix — the actual generated text both runners produced identically before splitting.

+ +

Base checkpoint

+ + + + + + + + + + +
#PromptDivergeNPU choice (rank in HF)HF choice (rank in NPU)Agreed prefix
0GPU stands for7 " special" (#2) " specialized" (#2)" Graphics Processing Unit. It is a"
1The capital of France is1 "," (#2) "." (#2)" Paris"
2Artificial intelligence is…7 "," (#2) "." (#2)" deals with the creation of intelligent machines"
3A neural network consists of3 " nodes" (#2) " interconnected" (#3)" a set of"
4Once upon a time, there was a robot…7 " little" (#2) " robot" (#2)" being a human. He was a"
5The COVID-19 pandemic…9 "," (#2) "." (#2)" has had a significant impact on the global economy"
6The Mona Lisa was painted by7 " and" (#2) "." (#3)" Leonardo da Vinci in 1503"
7The French translation…6 " prend" (#3) " g" (#2)" "Le premier oisif"
+ +

Instruct checkpoint

+ + + + + + + + + + +
#PromptDivergeNPU choice (rank in HF)HF choice (rank in NPU)Agreed prefix
0Introduce me what is GPU0 " acceleration" (#2) " (" (#2)(no prefix)
1Briefly describe…0 " Some" (#4) " Key" (#3)(no prefix)
2Compare and contrast…8 " (" (#4) " are" (#2)" Artificial intelligence (AI) and human intelligence"
3Describe the basic components…20 " multiple" (#2) " three" (#2)" \n\n## Step 1: Define the basic components of a neural network\nA neural network consists of"
4Write a short story…11 " model" (#3) " android" (#2)" It's a robot named Zeta, a highly advanced"
5Analyze the impact of COVID… (all 32 match) (all 32 match)(no divergence within sample)
6Explain the cultural significance…29 " Created" (#4) " It" (#2)" \n\nThe Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of the most famous paintings in the world."
7Translate the following…26 " Here" (#2) "The" (#2)" This is a common English idiom that means…"
+ +

Both checkpoints PASS the gate. Most divergences are #2/#2 swaps (both runners agreed on the same two top candidates; bf16 noise picked which ranked first); a few are #3/#4. None hit out-of-top-5. On Instruct, prompts 3, 6, 7 reach 20-29 tokens of agreement before splitting, and prompt 5 had zero divergence in the 32-token sample.

+ + +

B. make diagnosis — the inside lens

+ +

What it does

+ +

Single prompt's prefill on NPU + HF bf16, then computes per-position cosine + element-wise abs error for each layer's ffn_out (the block output). For layers 0…n_layers−2, both sides expose the raw layer output. For layer n_layers−1, both sides expose the post-final-RMSNorm hidden state — HF surfaces this as hidden_states[n_layers] (post-norm by HF v5.3 convention); NPU produces the equivalent via the same final_norm step it does inside its production LM-head GEMV path. So both L15 cells correspond to "the value the LM-head sees".

+ +

Diagnosis is informational only. No threshold, no pass/fail, no exit code based on the cosine. Verify is the correctness signal; the diagnosis table tells you where the NPU implementation drifts most from HF (which layer, by how much), which is what you want when triaging a real verify failure or weighing a kernel-side optimization.

+ +

Latest cosine tables (NPU FA on, prompt = "The capital of France is")

+ +

Same prompt, same NPU end-to-end path, both checkpoints. Run side-by-side so the per-layer precision shape can be compared directly.

+ +

Instruct (meta-llama/Llama-3.2-1B-Instruct)

+ + + + + + + + + + + + + + + + + + +
Layercos_p5cos_mincos_medianmax_abs
00.9932690.9932570.9937330.75
10.9264000.9081600.99095022
20.9272110.9085390.98837822
30.9406980.9276800.98820924
40.9518360.9405040.98746326
50.9593590.9501930.98815028
60.9652350.9588390.98839830
70.9692000.9649800.98805330
80.9750100.9735890.98935532
90.9815120.9806980.99048734
100.9838730.9831150.99094336
110.9811480.9788960.99044636
120.9769770.9733950.99002338
130.9753240.9709570.98989542
140.9716390.9669810.99031944
150.9706690.9663200.98750310.83
+ +

Base (meta-llama/Llama-3.2-1B)

+ + + + + + + + + + + + + + + + + + +
Layercos_p5cos_mincos_medianmax_abs
00.9919120.9912410.9940381.75
10.9660950.9595960.9896467
20.9602570.9523610.9883736
30.9589560.9505660.9861237
40.9700880.9654570.9859888
50.9727730.9694580.9855269
60.9747730.9739990.98387510
70.9719050.9688140.98266110
80.9555780.9491680.98720811
90.9604330.9591020.98953412
100.9659930.9659480.99081513
110.9549540.9491460.99097013
120.9411470.9294150.98979115
130.9367100.9231490.98886616
140.9293620.9122190.98790817
150.9394950.9242920.9903494.013
+ +

How to read it

+ +
    +
  1. Worst layer on either checkpoint is ~0.93. Comfortably inside the bf16 noise floor (NPU and HF are both bf16, so this is apples-to-apples). Cosine is direction-only, so the underlying per-position direction agreement is high across all 16 layers.
  2. +
  3. Different fine-tunes have different per-layer shapes. +
      +
    • Instruct: high at L0 (0.993), single dip at L1-L2 (~0.927), monotonic climb to a peak at L10 (0.984), gradual decline to ~0.971 by L15.
    • +
    • Base: high at L0 (0.992), early dip at L1-L3 (~0.96), small mid-stack peak at L4-L7 (~0.97), second dip reaching the floor at L12-L14 (~0.93), slight recovery at L15.
    • +
    + Different fine-tuning produces different activation distributions per layer; bf16 round-off interacts with those distributions differently. Both pass verify. +
  4. +
  5. Activation magnitude differs sharply between checkpoints. Base max_abs sits in the 6-17 range; Instruct sits in 22-44. Instruction tuning amplifies certain pathways; the bigger absolute deltas are not a precision problem (cosine is direction-only).
  6. +
  7. L15 is the post-final-norm cell. max_abs (~10 for Instruct, ~4 for base) is much smaller than mid-stack because final_norm rescales the hidden state to unit-variance-ish magnitude.
  8. +
+ + +

C. Why this design verifies production

+ +

Three things have to hold for make verify to be a meaningful correctness signal: the version we test must be the version that ships, the reference we compare against must be trustworthy, and the comparison criterion must be sound for bf16. We address each below.

+ +

1. NpuRunner runs the actual production code

+ +

NpuRunner directly imports and invokes the production functions — no reimplementation:

+ +
from llama32_1b_inference import prepare_runtime
+from llama32_1b_prefill   import run_transformer_block as run_prefill_block
+from llama32_1b_decode    import compile_decode_kernels, run_decode_block
+ +

NpuRunner.__init__ compiles the same kernels production compiles and runs the same prepare_runtime setup. NpuRunner.prefill calls run_prefill_block for each of the 16 layers, then runs the production 8-partition LM-head GEMV. NpuRunner.decode_step calls run_decode_block. If NpuRunner produces the right tokens, llama32_1b_inference.py produces the right tokens — by construction.

+ +

2. HF transformers in bf16 is the right reference

+ + + + + + +
CriterionChoice
Canonicaltransformers.AutoModelForCausalLM is the reference implementation that Meta + HuggingFace + the open-source LLM ecosystem maintain. Every bf16 LLM deployment (vLLM, llama.cpp, TRT-LLM, …) is qualified against this codebase.
Same dtypeLoaded as torch_dtype=torch.bfloat16, matching NPU production. Both sides hit the same bf16 round-off characteristics; the comparison is not testing a dtype gap.
Same weightsBoth runners load meta-llama/Llama-3.2-1B[-Instruct] from the same HF cache. Identical bytes on disk.
+ +

HfRunner is ~110 lines that delegate to self.model(input_ids, use_cache=True). No transformer-block reimplementation, no custom kernel — the simpler the reference, the harder it is for the reference to be wrong.

+ +

3. Top-k token-level inclusion is the right criterion for bf16

+ +

Continuous metrics (cosine, KL) on bf16 logits are fragile: bf16 ULP noise routinely flips per-step top-1 between two mathematically equivalent implementations. Discrete top-k inclusion is robust — bf16 noise can flip top-1 but rarely displaces a token from the top-5. compute_topk_set_check in comparators.py mirrors vLLM's tests/models/utils.py::check_logprobs_close; k=5 and n_tokens=32 are vLLM's defaults for the standard model gate.

+ +

One make verify run, end to end

+ +
+
Step 1. Load 8 prompts from verify/prompts/{instruct,base}.txt (selected by MODEL).
+
+
+
NpuRunner (production prefill + decode kernels, NPU FA on): greedy-decode 32 tokens, capturing chosen[i] + topk[i] (top-5 IDs) per step.
+
HfRunner (HF transformers in bf16): same 32-token greedy decode, same chosen[i] + topk[i] capture.
+
+
+
Step 3. compute_topk_set_check(npu_chosen, npu_topk, hf_chosen, hf_topk, k=5) walks both sequences in lockstep: +
    +
  • Same chosen → continue.
  • +
  • Different chosen → require both to land in the OTHER side's top-5; status OK or FAIL; stop.
  • +
+
+
+
Step 4. Repeat steps 2-3 for all 8 prompts; Report.has_failure() returns True iff any record is FAIL.
+
+
Step 5. Write verify_topk_token_*.{json,md}; exit 1 on FAIL else exit 0 (PASS).
+
+ +

What this catches and what it can miss

+ +

Catches (every step exercises the entire production stack):

+
    +
  • Kernel correctness regressions in GEMV / GEMM / RMSNorm / RoPE / FlashAttention / LM-head GEMV / embedding lookup — a wrong implementation shifts logits enough to push a chosen token out of HF's top-5 within 32 steps on at least one of 8 diverse prompts.
  • +
  • Pipeline glue regressions: KV-cache layout, weight pre-transpose, per-layer BO tagging, LM-head partition aggregation.
  • +
  • Fine-tune-specific behavior: gating Instruct and Base separately catches regressions on either weight distribution.
  • +
+ +

Can miss:

+
    +
  • Bugs that only manifest on prompts outside the 8 (the gate is finite; an lm-eval-harness GSM8K extension would broaden coverage).
  • +
  • Bugs that bias top-1 in a consistent direction without ever pushing a token out of top-5 (e.g., a uniform scale on every logit).
  • +
  • Code paths not exercised by the run (prompts longer than max_seq=2048, etc.).
  • +
+ +

File map

+ + + + + + + + + + + + +
FileResponsibility
Makefile (parent)verify / diagnosis / clean targets. MODEL=base|instruct, PROMPT=… for diagnosis.
verify/verify_runner.pyOrchestrator. Builds NPU + HF runners, loops prompts, calls the comparator, writes the report, exits 1 on FAIL.
verify/comparators.pytopk_token_ids (top-k with argmax-consistent tie-break), compute_topk_set_check (top-k token-level inclusion, mirrors vLLM's check_logprobs_close), plus diagnosis-only helpers (per_position_cosine, error_metrics, compare_pair).
verify/report.pyReport accumulator + JSON / markdown dumpers. has_failure() returns True iff any npu_vs_hf record is FAIL.
verify/runners/npu_runner.pyImports + invokes the production prefill / decode / LM-head functions.
verify/runners/hf_runner.pyLoads AutoModelForCausalLM in torch.bfloat16; delegates to model(input_ids, use_cache=True).
verify/runners/_records.pyPrefillRecord / DecodeStepRecord dataclasses shared by both runners.
verify/prompts/instruct.txt8 instruction-style prompts (MODEL=instruct); 7 from vllm/tests/prompts/example.txt + 1 GPU-related swap.
verify/prompts/base.txt8 continuation-style prompts (MODEL=base); incomplete sentences matched to base behavior.
+ +

Production-side touch points: llama32_1b_prefill.py::run_transformer_block populates ffn_out in the intermediates dict it already returns; diagnosis (which re-runs prefill layer-by-layer) reads it. Verify never reads any per-layer intermediates — it only consumes the final logits + chosen tokens.

+ + +

How to reproduce these numbers

+ +
cd programming_examples/llama32_1b
+
+make verify MODEL=instruct       # ~3m41s — top-k token-level inclusion gate, NPU vs HF bf16 (NPU FA on)
+make verify MODEL=base           # ~3m39s — base checkpoint, continuation prompts
+
+make diagnosis MODEL=instruct    # ~2m55s — per-layer ffn_out cosine table (NPU FA on)
+make diagnosis MODEL=base        # same lens, base checkpoint
+
+ +

Reports land in verify/reports/{verify_topk_token_,diagnosis_}YYYYMMDD-HHMMSS.{json,md} (gitignored). The chosen MODEL, model_name, and (for verify) prompts_file are recorded in the report config so the file is unambiguous.

+ +
+ +

Companion: IMPLEMENTATION_GUIDE.html Part C (the original CI smoke that this subsystem extends) · ABLATION_STUDY.html (sister study: 4-cell dispatch ablation).

+ + + + + diff --git a/programming_examples/llama32_1b/docs/explain.md b/programming_examples/llama32_1b/docs/explain.md index 58a399c81..737f7d994 100644 --- a/programming_examples/llama32_1b/docs/explain.md +++ b/programming_examples/llama32_1b/docs/explain.md @@ -249,8 +249,9 @@ The kernel exports the same `@rope` function name and signature as upstream, so no MLIR or multi-launch builder changes are needed. It is compiled to `rope.o` in `external_kernels.py:compile_rope()`. -The CPU reference (`llama32_1b_reference.py:apply_rope()`) uses the same half-split -convention, ensuring NPU and CPU produce identical results. +The NPU output is then gated against HuggingFace transformers in bf16 +(`make verify` — see [`VERIFICATION.html`](VERIFICATION.html)), +which exercises the same half-split RoPE convention end-to-end. --- diff --git a/programming_examples/llama32_1b/docs/profile.md b/programming_examples/llama32_1b/docs/profile.md index ce281b550..9550b699b 100644 --- a/programming_examples/llama32_1b/docs/profile.md +++ b/programming_examples/llama32_1b/docs/profile.md @@ -6,16 +6,28 @@ | Phase | AIR (NPU2) | IRON | Speedup | |-------|------------|------|---------| -| **Prefill** (seq_len=2048) | **1.27s wall** | 2.744s | **2.17x** | -| **Decode** (steady-state) | **92ms/token (10.8 tok/s)** | 370ms/token (2.7 tok/s) | **4.0x** | - -- **Wall time**: End-to-end from embedding to LM Head argmax (includes minimal - Python host overhead — KV-cache extraction, embedding lookup, numpy views) +| **Prefill / TTFT** (seq_len=2048) | **1.27s wall** | 2.744s | **2.17x** | +| **Decode / TPOT** (steady-state) | **92ms/token (10.8 tok/s)** | 370ms/token (2.7 tok/s) | **4.0x** | + +- **TTFT** (time-to-first-token): end-to-end from `make run` invocation to + first decoded token — includes tokenize + EOS-pad + embed + 16 layers + + final RMSNorm + LM head GEMV. Matches the vLLM / TGI / TRT-LLM TTFT + definition. With tokenize added back in, current measured TTFT is + ~1.28 s (the 1.27 s row above is the NPU-only fraction used + in the IRON comparison, since IRON does not bundle the tokenizer). +- **TPOT** (time-per-output-token): steady-state per-token decode latency + (excludes prefill / first-token cost). Drift across 30 decode tokens is + <1% — see `Per-Token Wall Trend` in `make profile` output. - **IRON baseline**: measured against the IRON reference at commit [`2b62dc7`](https://github.com/amd/IRON/commit/2b62dc77ecc72f0fa8fb3381b05579ab84778d27) of `amd/IRON`, same NPU2 hardware (Strix), same LLAMA-3.2-1B BF16 model, same `seq_len=2048`. +For the visual end-to-end dataflow with per-step measured timing and the +BO Write / NPU Run / BO Read concept walkthrough, see +[`PROFILE.html`](PROFILE.html). This file is the textual reference +(per-kernel tables, optimization history, vs IRON comparison). + **Recent optimizations** (vs. an earlier 1.54s wall headline): 1. Last-token-only LM Head: drop full-sequence NPU rmsnorm + 8-partition GEMM in prefill; do CPU rmsnorm on the 1×emb_dim last row (<1 ms) and reuse the @@ -88,13 +100,15 @@ Key differences favoring AIR: ## Prefill Breakdown (seq_len=2048, 16 layers) -### Wall Time Breakdown: 1.27s +### Wall Time Breakdown: 1.27s (NPU-only) / ~1.28s TTFT | Component | Time | Notes | |-----------|------|-------| -| **Kernel time** (sum of `load_and_run`) | ~1.16s | BO Write + NPU Run + BO Read (49 kernel calls: 16×3 transformer + 1 lm_head_gemv) | -| **Python host overhead** | ~0.11s | KV cache extraction, embedding lookup, CPU rmsnorm, numpy views | -| **Total wall time** | **1.27s** | | +| **NPU XRT calls** (sum of `load_and_run`) | ~1.12s | BO Write + NPU Run + BO Read across 49 calls: 16×3 transformer + 1 lm_head_gemv | +| **CPU host ops** (profiled) | ~37ms | tokenize + eos_pad + embed_lookup + 16×kv_cache_extract + final_rms_norm | +| **Python / numpy scheduling** | ~125ms | Per-layer dict access, numpy view setup, loop overhead (`layer-loop wall − inside-layer NPU − inside-layer CPU`) | +| **Total TTFT** (incl. tokenize) | **~1.28s** | matches `make run` Time-to-First-Token line | +| Total wall (NPU-only fraction, vs IRON) | ~1.27s | excludes tokenize; the row used in the IRON comparison | Overhead reduced from 0.67s → 0.24s by: - Suppressing print I/O in non-profile mode (4 prints × 16 layers) @@ -104,29 +118,41 @@ Overhead reduced from 0.67s → 0.24s by: - Skipping intermediate dict storage when not verifying - Removing redundant `.astype(bfloat16)` on already-bf16 kernel results -### Per-Kernel Timing +### Per-Kernel Timing (NPU XRT calls only) -| Kernel | Launches | Per-call | x Calls | Total | % | +| Kernel | Launches | Per-call | x Calls | Total | % of NPU | |--------|----------|----------|---------|-------|---| -| **o_ffn** | 8 | 41ms | 16 | **656ms** | **51%** | -| **flash_attn** | 1 | 22ms | 16 | **352ms** | **27%** | -| **lm_head** | 8 | 171ms | 1 | **171ms** | **13%** | -| **rms_gemms_rope** | 6 | 8ms | 16 | **128ms** | **10%** | -| rmsnorm | 1 | 3ms | 1 | 3ms | <1% | +| **o_ffn** | 8 (stitched) | 41.0ms | 16 | **656ms** | **59%** | +| **flash_attn** | 1 (separate ELF) | 21.6ms | 16 | **346ms** | **31%** | +| **rms_gemms_rope** | 6 (stitched) | 7.3ms | 16 | **117ms** | **10%** | +| **lm_head_gemv** | 8 partitions (stitched) | 13.6ms | 1 | **14ms** | **1%** | + +Per-CPU-op: -### Host vs NPU Breakdown (kernel time only) +| CPU op | Per-call | x Calls | Total | +|--------|----------|---------|-------| +| tokenize | ~10 ms | 1 | ~10 ms | +| eos_pad | <0.1 ms | 1 | <0.1 ms | +| embed_lookup | 5.8 ms | 1 | 5.8 ms | +| kv_cache_extract | 1.1 ms | 16 | 17.6 ms | +| final_rms_norm | 3.1 ms | 1 | 3.1 ms | + +### Host vs NPU Breakdown (XRT calls only — `cache.load_and_run` internals) | | BO Write | NPU Run | BO Read | Total | |---|----------|---------|---------|-------| -| **Sum** | 48ms | 1237ms | 9ms | 1294ms | -| **%** | **4%** | **96%** | **1%** | 100% | +| **Sum** | 46ms | 1062ms | 5ms | 1113ms | +| **%** | **4%** | **95%** | **0%** | 100% | + +(BO Read is zero-copy view construction — see PROFILE.html Part C for what +these three segments actually measure.) ### Per-Layer Data Flow ``` Layer input: x_bf16 (2048x2048, 8MB) -┌─ KERNEL 1: rms_gemms_rope (8ms/layer) ─────────────────────────┐ +┌─ KERNEL 1: rms_gemms_rope (7.3ms/layer) ───────────────────────┐ │ │ │ WRITE: x_in (8MB) ← activation, changes/layer │ │ SKIP: norm_w, wq, wk, wv ← STATIC (per-layer BO) │ @@ -142,7 +168,7 @@ Layer input: x_bf16 (2048x2048, 8MB) │ READ: v (2MB), q_roped (8MB), k_roped (2MB) │ └────────────────────────────┬────────────────────────────────────┘ ▼ -┌─ KERNEL 2: flash_attn (22ms/layer) ────────────────────────────┐ +┌─ KERNEL 2: flash_attn (21.6ms/layer) ──────────────────────────┐ │ │ │ WRITE: q_roped (8MB), k_roped (2MB), v (2MB) │ │ SKIP: attn_out ← INTERMEDIATE │ @@ -173,8 +199,11 @@ Layer input: x_bf16 (2048x2048, 8MB) └─────────────────────────────────────────────────────────────────┘ × 16 layers, then: - rmsnorm (3ms): Final layer normalization - lm_head (171ms): 8-partition GEMM → vocab logits → argmax → first token + final_rms_norm (CPU, 3.1ms): RMSNorm on single prediction-position row + lm_head_gemv (NPU, 13.6ms): 8-partition GEMV → vocab logits → argmax → first token + (reuses the decode-side 8-partition ELF; see + A7 in IMPLEMENTATION_GUIDE.html for why + full-seq GEMM was dropped in favor of single-row GEMV) ``` --- diff --git a/programming_examples/llama32_1b/docs/usage.md b/programming_examples/llama32_1b/docs/usage.md index 990e2a823..8ffc20e00 100644 --- a/programming_examples/llama32_1b/docs/usage.md +++ b/programming_examples/llama32_1b/docs/usage.md @@ -102,41 +102,64 @@ What happens internally: ### `make profile` -Same as `make run` but prints per-token timing and kernel breakdown. +Same as `make run` but enables the otherwise-disabled `Profiler` so the +end-to-end inference path is broken down into per-XRT-call and per-CPU-op +wall times. Production code path is identical to `make run`. ```bash make profile -make profile N_TOKENS=10 +make profile N_TOKENS=30 PROMPT="Explain photosynthesis in detail." ``` -Example output (with `N_TOKENS=10`): -``` -NPU prefill done in 1.27s. First token: 12366 +After the model output, the report prints (per phase: prefill / decode): + +1. **END-TO-END DATAFLOW** — architecture-aware summary in dataflow order + (tokenize → eos_pad → embed → 16×(rms_gemms_rope + flash_attn + o_ffn + + kv_cache_extract) → final_norm → lm_head_gemv → per-query total). + Mirrors the SVGs in [`PROFILE.html`](PROFILE.html). +2. **Wall-Time Attribution** — totals: NPU XRT vs CPU host ops vs layer-loop. +3. **Per-Layer Execution** — one row per prefill layer; aggregated avg/min/max + per layer across tokens for decode. +4. **NPU XRT Call Breakdown** — each multi-launch ELF, wall time per call. +5. **CPU Op Breakdown** — each tracked CPU host op (embed, kv_cache_extract, + final_rms_norm, tokenize, eos_pad, decode_attention_cpu). +6. **Fine-Grained NPU Breakdown** — each XRT call split into + `BO Write` / `NPU Run` / `BO Read` (concept explained in PROFILE.html + Part C). +7. **Per-Token Wall Trend** (decode only) — token 1 / middle / last wall + + first→last drift %, so you can spot any KV-cache-growth-driven slowdown. + +For reproduction commands + visual dataflow + concept walkthrough see +[`PROFILE.html`](PROFILE.html). + +### `make verify` -Decoding 10 tokens (token 1 to 10)... - Token 1: id=13, time=92ms - Token 2: id=1102, time=91ms - ... - Token 10: id=578, time=92ms +Top-k token-level inclusion gate against HuggingFace transformers in **bf16** +(same dtype as NPU). Greedy-decodes 8 pre-selected prompts × 32 tokens; at +each step, both runners' chosen tokens must appear in the OTHER side's top-5. +Pass/fail signal for end-to-end production correctness (~4 min). Mirrors +vLLM's `check_logprobs_close` method. -Generated 10 tokens in 0.92s -Tokens/second: 10.87 -Time/token: 92ms +```bash +make verify # default MODEL=instruct +make verify MODEL=base # base checkpoint, continuation prompts ``` -### `make verify` +Token count and `k` are fixed by the gate (32 / 5) — not user-tunable. + +### `make diagnosis` -Runs inference and compares every intermediate result against a CPU F32 reference. -Useful for validating correctness after kernel changes. +Per-layer `ffn_out` cosine + max_abs error vs HF bf16 for a single prompt. +Informational only (never fails the run); reach for it when `make verify` +flags a regression and you need to localize which layer drifted. ```bash -make verify N_TOKENS=10 +make diagnosis # uses default PROMPT +make diagnosis PROMPT="The capital of France is" ``` -Checks: -- Per-layer KV cache correlation (NPU vs CPU) -- Logits correlation at prediction position -- Top-1 token match +See [VERIFICATION.html](VERIFICATION.html) for the full design rationale, +gate criteria, and report layout. ### `make clean` @@ -175,7 +198,7 @@ llama32_1b/ ├── llama32_1b_prefill.py ← Prefill-only pipeline ├── llama32_1b_decode.py ← Decode-only pipeline ├── llama32_1b_weights.py ← Weight loading from safetensors -├── llama32_1b_reference.py ← CPU F32 reference +├── llama32_1b_cpu_helpers.py ← Small NumPy helpers: rms_norm, attention_reference, softmax │ ├── kernel_builder/ ← Shared kernel infrastructure │ ├── stitching.py ← MLIR text stitching for multi-launch ELFs @@ -212,5 +235,7 @@ llama32_1b/ **Slow first token**: The NPU enters power-save after ~10s idle. The warmup pass handles this automatically. If running manually, ensure `prepare_runtime()` is called. -**Wrong results**: Run `make verify` to compare against CPU reference. Check that -`.o` files are fresh (`make clean` then `make compile`). +**Wrong results**: Run `make verify` to gate against HuggingFace transformers +bf16 (top-k token inclusion). If verify fails, run `make diagnosis` to +localize which layer drifted. Check that `.o` files are fresh +(`make clean` then `make compile`). diff --git a/programming_examples/llama32_1b/kernel_builder/cache.py b/programming_examples/llama32_1b/kernel_builder/cache.py index d35dca937..a83df46e5 100644 --- a/programming_examples/llama32_1b/kernel_builder/cache.py +++ b/programming_examples/llama32_1b/kernel_builder/cache.py @@ -45,7 +45,6 @@ def prepare_air_project(): "attn_npu2.o", "mv.o", "mv_k8192.o", - "attn_decode_npu2.o", ]: src = Path(obj_name) if src.exists(): @@ -58,7 +57,8 @@ class Profiler: def __init__(self, enabled=False): self.enabled = enabled self.compile_times = {} # name -> seconds - self.kernel_times = {} # name -> list of seconds + self.kernel_times = {} # NPU XRT call: name -> list of seconds + self.cpu_times = {} # CPU op: name -> list of seconds self.layer_times = [] # list of (layer_idx, seconds) self.kernel_breakdowns = ( {} @@ -72,6 +72,15 @@ def record_kernel(self, name, duration): if self.enabled: self.kernel_times.setdefault(name, []).append(duration) + def record_cpu(self, name, duration): + """Record a CPU host-side operation's wall time. Use for things like + embed lookup, KV-cache extract, CPU attention fallback, final RMSNorm + — anything that is not an `xrt.run()` but consumes inference wall + time. Reported in a separate section from NPU XRT calls so the two + are easy to compare.""" + if self.enabled: + self.cpu_times.setdefault(name, []).append(duration) + def record_breakdown( self, name, write_ms, kernel_ms, read_ms, n_written, bytes_written, n_readback ): @@ -89,12 +98,45 @@ def record_breakdown( def start_layer(self): if self.enabled: - return time.time() + return time.perf_counter() return None def end_layer(self, layer_idx, t0): if self.enabled and t0 is not None: - self.layer_times.append((layer_idx, time.time() - t0)) + self.layer_times.append((layer_idx, time.perf_counter() - t0)) + + def time_cpu(self, name): + """Context manager: `with prof.time_cpu("embed_lookup"): ...` + Records the elapsed wall time as a CPU op named `name`. Safe to + use whether enabled or disabled (zero overhead when disabled).""" + prof = self + + class _Ctx: + def __enter__(self_inner): + self_inner.t0 = time.perf_counter() if prof.enabled else None + return self_inner + + def __exit__(self_inner, *exc): + if self_inner.t0 is not None: + prof.record_cpu(name, time.perf_counter() - self_inner.t0) + return False + + return _Ctx() + + def per_token_walls_ms(self, n_layers): + """Sum every consecutive `n_layers` layer-time entries into one + per-token wall (in ms). Returns [] if not enabled or no data. + Used by the dataflow summary to expose decode slowdown trends.""" + if not self.enabled or not self.layer_times: + return [] + if len(self.layer_times) % n_layers != 0: + # Shouldn't happen in a clean run; bail rather than mis-bucket. + return [] + out = [] + for tok_start in range(0, len(self.layer_times), n_layers): + chunk = self.layer_times[tok_start : tok_start + n_layers] + out.append(sum(t for _, t in chunk) * 1000.0) + return out def report(self): if not self.enabled: @@ -104,6 +146,36 @@ def report(self): print("PROFILING REPORT") print(f"{'='*60}") + # Top-level phase summary: total wall time attributed to NPU XRT + # calls vs CPU host ops vs the layer envelope. Sums won't add up + # exactly (layer envelope is the wall budget; NPU + CPU are the + # accounted-for parts inside it; remainder is python scheduling / + # numpy view setup / loop overhead). Useful as a sanity check. + if self.kernel_times or self.cpu_times or self.layer_times: + npu_total_ms = sum(t * 1000 for v in self.kernel_times.values() for t in v) + cpu_total_ms = sum(t * 1000 for v in self.cpu_times.values() for t in v) + layer_total_ms = sum(t * 1000 for _, t in self.layer_times) + npu_count = sum(len(v) for v in self.kernel_times.values()) + cpu_count = sum(len(v) for v in self.cpu_times.values()) + print(f"\n--- Wall-Time Attribution ---") + if npu_count: + print( + f" NPU XRT calls {npu_total_ms:9.2f}ms ({npu_count} calls)" + ) + if cpu_count: + print( + f" CPU host ops {cpu_total_ms:9.2f}ms ({cpu_count} calls)" + ) + if self.layer_times: + accounted = npu_total_ms + cpu_total_ms + # CPU ops happen both inside and outside the layer envelope; + # so layer_total_ms is the inside-layer wall budget, and the + # remainder vs (NPU+CPU) inside layers is python overhead. + print( + f" Layer-loop wall {layer_total_ms:9.2f}ms " + f"({len(self.layer_times)} layer-invocations)" + ) + if self.compile_times: print(f"\n--- Compilation Phase ---") total_compile = 0 @@ -115,34 +187,71 @@ def report(self): ) if self.layer_times: - print(f"\n--- Per-Layer Execution ---") + # Group by layer_idx. Prefill: each idx appears once -> one row per + # layer. Decode: each idx appears once per token -> aggregate with + # avg / min / max / count. + from collections import defaultdict + + grouped = defaultdict(list) for idx, t in self.layer_times: - print(f" Layer {idx:3d}: {t:8.2f}s") - total_layers = sum(t for _, t in self.layer_times) - print(f" {'Total prefill':40s} {total_layers:8.2f}s") + grouped[idx].append(t * 1000.0) # ms + multi_invocation = any(len(v) > 1 for v in grouped.values()) + print(f"\n--- Per-Layer Execution ---") + if multi_invocation: + for idx in sorted(grouped): + ts = grouped[idx] + print( + f" Layer {idx:3d}: avg={sum(ts)/len(ts):7.2f}ms " + f"min={min(ts):7.2f}ms max={max(ts):7.2f}ms (x{len(ts)})" + ) + else: + for idx in sorted(grouped): + print(f" Layer {idx:3d}: {grouped[idx][0]:7.2f}ms") + total_ms = sum(t * 1000.0 for _, t in self.layer_times) + print(f" {'Total layer-time':40s} {total_ms:8.2f}ms") if self.kernel_times: - print(f"\n--- Kernel Breakdown (avg per invocation) ---") + print(f"\n--- NPU XRT Call Breakdown (avg per invocation) ---") total_avg = 0 for name, times in sorted(self.kernel_times.items()): - avg = sum(times) / len(times) - total_avg += avg * len(times) - mn = min(times) - mx = max(times) - count = len(times) + times_ms = [t * 1000.0 for t in times] + avg = sum(times_ms) / len(times_ms) + total_avg += avg * len(times_ms) + count = len(times_ms) print( - f" {name:40s} avg={avg:6.3f}s " - f"min={mn:6.3f}s max={mx:6.3f}s (x{count})" + f" {name:40s} avg={avg:7.2f}ms " + f"min={min(times_ms):7.2f}ms max={max(times_ms):7.2f}ms (x{count})" ) if self.layer_times: n_layers = len(self.layer_times) - print(f" {'Total kernel time':40s} {total_avg:8.2f}s") + print(f" {'Total kernel time':40s} {total_avg:8.2f}ms") print( - f" {'Avg per layer (kernel time)':40s} {total_avg/n_layers:8.2f}s" + f" {'Avg per layer (kernel time)':40s} {total_avg/n_layers:8.2f}ms" ) + if self.cpu_times: + print(f"\n--- CPU Op Breakdown (avg per invocation) ---") + total_cpu_ms = 0 + for name, times in sorted(self.cpu_times.items()): + times_ms = [t * 1000.0 for t in times] + avg = sum(times_ms) / len(times_ms) + total_cpu_ms += avg * len(times_ms) + count = len(times_ms) + print( + f" {name:40s} avg={avg:7.2f}ms " + f"min={min(times_ms):7.2f}ms max={max(times_ms):7.2f}ms (x{count})" + ) + print(f" {'Total CPU op time':40s} {total_cpu_ms:8.2f}ms") + if self.kernel_breakdowns: - print(f"\n--- Fine-Grained Breakdown (avg per invocation) ---") + print(f"\n--- Fine-Grained NPU Breakdown (avg per invocation) ---") + print( + f" Three-segment timing of each XRT call:\n" + f" BO Write = host→DDR memcpy of dynamic inputs (weights\n" + f" pre-loaded once via static_input_indices)\n" + f" NPU Run = xrt.run.start() + wait() — actual NPU exec\n" + f" BO Read = numpy view construction (zero-copy, ~0)" + ) print( f" {'Kernel':20s} {'BO Write':>10s} {'NPU Run':>10s} {'BO Read':>10s} {'Total':>10s} {'Written':>8s} {'Read':>6s}" ) @@ -301,6 +410,7 @@ def load_and_run( static_input_indices=None, intermediate_indices=None, bo_key=None, + naive=False, ): """Load cached kernel and execute with BO reuse. @@ -316,8 +426,20 @@ def load_and_run( output_indices: Optional list of buffer indices to read back from device. If None, only the last buffer is read back (default). Use for multi-output kernels (e.g. attn_gemms: [2, 4, 6]). + static_input_indices: Optional set of buffer indices that are static + (e.g. weights, LUTs). On the first call for a given bo_key the BO is + written; on subsequent calls the host->device sync is skipped because + the kernel reads from the already-resident BO. intermediate_indices: Optional set of buffer indices that are intermediate (overwritten by kernel). Skips host->device sync. + bo_key: Optional cache key for BO reuse. Calls sharing a bo_key reuse + the same xrt.bo objects, which combined with static_input_indices + enables write-once-read-many for weights. Default uses the kernel + name (one BO set shared across all calls to that kernel). + naive: If True, force-write every input and force-read every output + on every call regardless of static_input_indices or + intermediate_indices. Used by ablation Cell A to establish a + baseline that never skips any host<->device transfer. Returns: Tuple of numpy arrays (all kernel outputs) @@ -326,6 +448,12 @@ def load_and_run( import pyxrt as xrt from air.backend.xrt import XRTBackend + if naive: + # Force-write everything, force-read everything. Used by ablation Cell A. + static_input_indices = set() + intermediate_indices = set() + output_indices = list(range(len(inputs))) + if name not in self.artifacts: raise RuntimeError( f"Kernel '{name}' not found in cache. " diff --git a/programming_examples/llama32_1b/kernel_builder/external_kernels.py b/programming_examples/llama32_1b/kernel_builder/external_kernels.py index 02287e390..3613658fc 100644 --- a/programming_examples/llama32_1b/kernel_builder/external_kernels.py +++ b/programming_examples/llama32_1b/kernel_builder/external_kernels.py @@ -12,7 +12,6 @@ """ import os -import shutil import subprocess from pathlib import Path @@ -27,28 +26,30 @@ def _get_peano_clang(): def _get_aie_include_dir(): """Find the AIE API include directory (for aie_api/aie.hpp).""" - # Primary: locate via aie-opt on PATH. Matches the convention used by - # every other Makefile in this repo (AIEOPT_DIR = $(dir $(which aie-opt))/..) - # and works for both local source builds and CI's mlir_aie wheel install. - aie_opt = shutil.which("aie-opt") - if aie_opt: - p = Path(aie_opt).resolve().parent.parent / "include" - if (p / "aie_api" / "aie.hpp").exists(): - return str(p) - # Fallback: explicit local dev install path. - p = ( + # Try mlir-aie install path relative to this file (main-repo layout) + candidates = [ Path(__file__).resolve().parent.parent.parent.parent / "my_install" / "mlir-aie" / "install" - / "include" - ) - if (p / "aie_api" / "aie.hpp").exists(): - return str(p) - raise RuntimeError( - "Cannot find aie_api/aie.hpp include directory " - "(no aie-opt on PATH and no my_install/mlir-aie/install)" - ) + / "include", + ] + # Also honour MLIR_AIE_INSTALL_DIR env var (set by env_setup.sh; works + # in git worktrees where the relative path above resolves to the worktree + # root rather than the main repo root). + mlir_aie_dir = os.environ.get("MLIR_AIE_INSTALL_DIR", "") + if mlir_aie_dir: + candidates.append(Path(mlir_aie_dir) / "include") + for p in candidates: + if (p / "aie_api" / "aie.hpp").exists(): + return str(p) + # Fallback: search from PEANO_INSTALL_DIR + peano_dir = os.environ.get("PEANO_INSTALL_DIR", "") + if peano_dir: + p = Path(peano_dir).parent.parent / "include" + if (p / "aie_api" / "aie.hpp").exists(): + return str(p) + raise RuntimeError("Cannot find aie_api/aie.hpp include directory") _PEANO_FLAGS = [ @@ -171,20 +172,6 @@ def compile_mv(tile_m=8): _compile_kernel(src, "mv.o", extra_flags=[f"-DDIM_M_OUTPUT={tile_m}"]) -def compile_attn_decode_npu2(head_dim=64): - """Compile attn_decode_npu2.o (RoPE helpers for the fused decode kernel).""" - src = _PROJ_ROOT / "attention_decode" / "attn_decode_npu2.cc" - _compile_kernel( - src, - "attn_decode_npu2.o", - extra_flags=[ - f"-DDIM_N={head_dim}", - f"-DHEAD_SIZE={head_dim}", - "-DAIE_API_EMULATE_BFLOAT16_MMUL_WITH_BFP16", - ], - ) - - def compile_all_external_kernels(head_dim=64): """Compile all external C++ kernels from source. @@ -195,6 +182,5 @@ def compile_all_external_kernels(head_dim=64): compile_silu_and_mul() compile_rope() compile_attn_npu2(head_dim=head_dim) - compile_attn_decode_npu2(head_dim=head_dim) compile_mv() compile_mv_k8192() diff --git a/programming_examples/llama32_1b/llama32_1b_cpu_helpers.py b/programming_examples/llama32_1b/llama32_1b_cpu_helpers.py new file mode 100644 index 000000000..72a854e96 --- /dev/null +++ b/programming_examples/llama32_1b/llama32_1b_cpu_helpers.py @@ -0,0 +1,88 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + +"""Small NumPy CPU helpers shared by production prefill/decode + verify. + +This file used to be a full F32 CPU forward-pass implementation of the model +(plus a standalone `--verify` CLI that compared the F32 forward against HF +transformers F32). With the verify subsystem rewritten to compare directly +against HF transformers in bf16 (see verify/), that whole F32 reference +chain became redundant. What is kept here is the small set of NumPy helpers +that production still imports: + + - rms_norm : LM-head GEMV final-norm (inference.py prefill end, + and every decode step). + - attention_reference: prefill cpu_attn=True fallback (full GQA attention + in F32 on host; used when the NPU FlashAttention + kernel is unavailable for the configured head_dim). + - softmax : kept because attention_reference uses it; not + imported anywhere else. +""" + +import numpy as np + + +def rms_norm(x, weight, eps=1e-5): + """RMS normalization: x / sqrt(mean(x^2) + eps) * weight. + + Args: + x: (M, N) input array in F32. + weight: (N,) learned scale parameter. + eps: Small constant for numerical stability. + + Returns: + (M, N) normalized and scaled array in F32. + """ + x = np.asarray(x, dtype=np.float32) + weight = np.asarray(weight, dtype=np.float32) + rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) + return (x / rms) * weight + + +def softmax(x, axis=-1): + """Numerically stable softmax (used by attention_reference).""" + x = np.asarray(x, dtype=np.float32) + x_max = np.max(x, axis=axis, keepdims=True) + exp_x = np.exp(x - x_max) + return exp_x / np.sum(exp_x, axis=axis, keepdims=True) + + +def attention_reference(q, k, v, n_heads, n_kv_heads): + """Multi-head attention with Grouped Query Attention (GQA), causal mask. + + Args: + q: (seq_len, n_heads * head_dim) -- already projected and RoPE'd. + k: (seq_len, n_kv_heads * head_dim) -- already projected and RoPE'd. + v: (seq_len, n_kv_heads * head_dim) -- already projected. + n_heads: Number of query heads. + n_kv_heads: Number of key/value heads (for GQA). + + Returns: + (seq_len, n_heads * head_dim) attention output (F32). + """ + q = np.asarray(q, dtype=np.float32) + k = np.asarray(k, dtype=np.float32) + v = np.asarray(v, dtype=np.float32) + + seq_len = q.shape[0] + head_dim = q.shape[1] // n_heads + group_size = n_heads // n_kv_heads + + # Reshape to per-head views: (seq, n_*_heads, head_dim) -> (n_*_heads, seq, head_dim) + q = q.reshape(seq_len, n_heads, head_dim).transpose(1, 0, 2) + k = k.reshape(seq_len, n_kv_heads, head_dim).transpose(1, 0, 2) + v = v.reshape(seq_len, n_kv_heads, head_dim).transpose(1, 0, 2) + + scale = 1.0 / np.sqrt(head_dim) + causal_mask = np.triu(np.full((seq_len, seq_len), -np.inf, dtype=np.float32), k=1) + + out_heads = np.empty((n_heads, seq_len, head_dim), dtype=np.float32) + for h in range(n_heads): + kv_idx = h // group_size + scores = q[h] @ k[kv_idx].T * scale + scores = scores + causal_mask + probs = softmax(scores, axis=-1) + out_heads[h] = probs @ v[kv_idx] + + # (n_heads, seq, head_dim) -> (seq, n_heads * head_dim) + return out_heads.transpose(1, 0, 2).reshape(seq_len, n_heads * head_dim) diff --git a/programming_examples/llama32_1b/llama32_1b_decode.py b/programming_examples/llama32_1b/llama32_1b_decode.py index ccb80cdee..37de7d75c 100644 --- a/programming_examples/llama32_1b/llama32_1b_decode.py +++ b/programming_examples/llama32_1b/llama32_1b_decode.py @@ -157,7 +157,7 @@ def run_decode_block( rope_lut_bf16: (max_seq, head_dim) RoPE LUT Returns: - output: (emb_dim,) — block output + output: (emb_dim,) — block output. """ emb_dim = config.emb_dim n_heads = config.n_heads @@ -232,15 +232,19 @@ def _run(name, backend, *inputs, static_indices=None, **kwargs): v_cache_layer[:, current_pos, :] = v.reshape(n_kv_heads, head_dim) # --- CPU Attention --- - attn_out = decode_attention_cpu( - q_roped.flatten(), - k_cache_layer, - v_cache_layer, - current_pos, - n_heads, - n_kv_heads, - head_dim, - ) + # Single-query attention against the growing K/V cache. CPU-side because + # at head_dim=64 the NPU FA kernel's per-call overhead dominates the + # single-query workload. + with cache.profiler.time_cpu("decode_attention_cpu"): + attn_out = decode_attention_cpu( + q_roped.flatten(), + k_cache_layer, + v_cache_layer, + current_pos, + n_heads, + n_kv_heads, + head_dim, + ) # --- Call 2: o_gemv_ffn (8 launches, 15 args) --- # O GEMV + Add + RMSNorm + Gate/Up GEMV + SiLU*mul + Down GEMV + Add @@ -281,6 +285,4 @@ def _run(name, backend, *inputs, static_indices=None, **kwargs): static_indices={0, 7, 9, 12}, intermediate_indices={2, 4, 6, 8, 10, 11, 13, 14}, ) - output = results[14].astype(bfloat16) - - return output + return results[14].astype(bfloat16) diff --git a/programming_examples/llama32_1b/llama32_1b_inference.py b/programming_examples/llama32_1b/llama32_1b_inference.py index 18c9de206..a4b768a43 100644 --- a/programming_examples/llama32_1b/llama32_1b_inference.py +++ b/programming_examples/llama32_1b/llama32_1b_inference.py @@ -17,7 +17,6 @@ # Run inference with cached kernels: python3 ../llama32_1b_inference.py --run-only --n-tokens 10 --profile python3 ../llama32_1b_inference.py --run-only --n-tokens 100 --profile - python3 ../llama32_1b_inference.py --run-only --n-tokens 5 --verify python3 ../llama32_1b_inference.py --run-only --n-tokens 20 --prompt "Once upon a time" """ @@ -37,10 +36,9 @@ from llama32_1b_weights import ( LlamaConfig, load_weights, - synthetic_weights, generate_rope_lut, ) -from kernel_builder.cache import KernelCache +from kernel_builder.cache import KernelCache, Profiler from kernel_builder.external_kernels import compile_all_external_kernels from kernel_builder.backend_presets import ( LM_GEMV_BACKEND, @@ -82,21 +80,6 @@ def _delta_text(tokenizer: Any, ids: list[int], state: _StreamState) -> str: return delta -class _SyntheticTokenizer: - """Stub tokenizer used with --synthetic-weights (no HuggingFace dependency). - - The synthetic path skips real tokenization entirely (token IDs come from a - deterministic numpy array); this stub satisfies the few attribute lookups - the pipeline still does — eos_token_id (decode-loop stop) and decode() - (verify/profile prints). - """ - - eos_token_id = -1 # never matches real token ids; decode loop runs full N - - def decode(self, ids, skip_special_tokens=False): # noqa: ARG002 - return f"" if isinstance(ids, list) else f"" - - # --------------------------------------------------------------------------- # Session: long-lived state created once per process # --------------------------------------------------------------------------- @@ -214,6 +197,10 @@ def prepare_runtime( t_prep = time.time() - t0 print(f" Runtime prepared in {t_prep:.1f}s") + # Stash on both profilers for the dataflow summary (one-time cost, + # outside per-query wall but useful context). + prefill_cache.profiler.preprocessing_s = t_prep + decode_cache.profiler.preprocessing_s = t_prep def _preload_decode_weights(decode_cache, weights, config): @@ -236,6 +223,12 @@ def _preload_decode_weights(decode_cache, weights, config): print(" Pre-loading decode weights into per-layer BOs...") + # Suppress profiling during warmup — these BO-allocate / weight-write + # calls happen in prepare_runtime (outside the user-visible wall time + # for prefill / decode). Mirrors the same guard in preload_prefill_weights. + _was_enabled = decode_cache.profiler.enabled + decode_cache.profiler.enabled = False + rope_lut_q_dummy = np.zeros(n_heads * head_dim, dtype=bfloat16) rope_lut_k_dummy = np.zeros(n_kv_heads * head_dim, dtype=bfloat16) @@ -315,6 +308,10 @@ def _preload_decode_weights(decode_cache, weights, config): intermediate_indices={2 + 2 * p for p in range(_LM_N_PARTITIONS)}, ) + # Restore profiler state — subsequent decode_cache.load_and_run calls + # (from prefill end + decode loop) record timing as intended. + decode_cache.profiler.enabled = _was_enabled + weights._decode_weights_preloaded_to_bos = True total_mb = ( config.n_layers @@ -349,13 +346,16 @@ def run_npu_prefill( tokenizer, cpu_attn=True, profile=False, - verify=False, quiet=False, ): """Run NPU prefill and extract KV cache for decode. Returns: - prefill_token: int -- first predicted token ID + prefill_token: int -- first predicted token ID (= argmax(logits_row)) + logits_row: (vocab_size,) f32 -- raw NPU LM-head logits at the + prediction position (before argmax). Production + callers can discard with `_`; the verify subsystem + reads this for top-k extraction. k_cache: (n_layers, n_kv_heads, max_seq, head_dim) bfloat16 v_cache: (n_layers, n_kv_heads, max_seq, head_dim) bfloat16 prompt_len: actual prompt length (before padding) @@ -369,9 +369,10 @@ def run_npu_prefill( k_cache = np.zeros((config.n_layers, n_kv_heads, max_seq, head_dim), dtype=bfloat16) v_cache = np.zeros((config.n_layers, n_kv_heads, max_seq, head_dim), dtype=bfloat16) - # Token embedding - embed_f32 = weights.embed_table[token_ids].astype(np.float32) - x_bf16 = embed_f32.astype(bfloat16) + # Token embedding (CPU gather + dtype casts) + with prefill_cache.profiler.time_cpu("embed_lookup"): + embed_f32 = weights.embed_table[token_ids].astype(np.float32) + x_bf16 = embed_f32.astype(bfloat16) # ---- TIMED SECTION START ---- if not quiet: @@ -380,7 +381,7 @@ def run_npu_prefill( # Run 16 transformer layers on NPU, collecting KV cache for layer_idx in range(config.n_layers): - layer_t0 = time.perf_counter() if profile else None + t0 = prefill_cache.profiler.start_layer() x_bf16, intermediates = run_transformer_block( x_bf16, @@ -389,29 +390,27 @@ def run_npu_prefill( config, prefill_cache, layer_idx=layer_idx, - verify=verify, cpu_attn=cpu_attn, verbose=profile, ) - # Extract KV cache from intermediates - k_roped = intermediates["k_roped"] - v_raw = intermediates["v"] - - k_cache[layer_idx, :, :seq_len, :] = ( - k_roped.astype(bfloat16) - .reshape(seq_len, n_kv_heads, head_dim) - .transpose(1, 0, 2) - ) - v_cache[layer_idx, :, :seq_len, :] = ( - v_raw.astype(bfloat16) - .reshape(seq_len, n_kv_heads, head_dim) - .transpose(1, 0, 2) - ) + # Extract KV cache from intermediates (CPU: reshape + transpose + + # cast + slice-assign). 16 invocations per prefill, one per layer. + with prefill_cache.profiler.time_cpu("kv_cache_extract"): + k_roped = intermediates["k_roped"] + v_raw = intermediates["v"] + k_cache[layer_idx, :, :seq_len, :] = ( + k_roped.astype(bfloat16) + .reshape(seq_len, n_kv_heads, head_dim) + .transpose(1, 0, 2) + ) + v_cache[layer_idx, :, :seq_len, :] = ( + v_raw.astype(bfloat16) + .reshape(seq_len, n_kv_heads, head_dim) + .transpose(1, 0, 2) + ) - if profile: - layer_t = time.perf_counter() - layer_t0 - print(f" Layer {layer_idx:2d}: {layer_t*1000:.0f}ms") + prefill_cache.profiler.end_layer(layer_idx, t0) # Final RMSNorm + LM Head — single-position only. # Autoregressive generation only needs logits at the last real-token row; @@ -422,12 +421,14 @@ def run_npu_prefill( prompt_len = len([t for t in token_ids if t != tokenizer.eos_token_id]) pred_pos = prompt_len - 1 - from llama32_1b_reference import rms_norm as _rms_norm + from llama32_1b_cpu_helpers import rms_norm - last_hidden = np.asarray(x_bf16, dtype=np.float32)[pred_pos : pred_pos + 1] - last_normed_bf16 = ( - _rms_norm(last_hidden, weights.final_norm).flatten().astype(bfloat16) - ) + # Final RMSNorm on the single prediction-position row (CPU; <1 ms). + with prefill_cache.profiler.time_cpu("final_rms_norm"): + last_hidden = np.asarray(x_bf16, dtype=np.float32)[pred_pos : pred_pos + 1] + last_normed_bf16 = ( + rms_norm(last_hidden, weights.final_norm).flatten().astype(bfloat16) + ) # NPU LM Head GEMV — reuse the decode-cache 8-partition GEMV ELF lm_inputs = [last_normed_bf16] @@ -450,69 +451,101 @@ def run_npu_prefill( if not quiet: print(f"NPU prefill done in {t_prefill:.2f}s. First token: {prefill_token}") - # --- Verification: compare against CPU F32 reference --- - if verify: - print(f"\n{'='*60}") - print("Verification: NPU prefill vs CPU F32 reference") - print(f"{'='*60}") - from llama32_1b_reference import transformer_block as cpu_block, rms_norm + return prefill_token, logits_row, k_cache, v_cache, prompt_len - rope_lut_f32 = rope_lut_bf16[:seq_len].astype(np.float32) - x_cpu = weights.embed_table[token_ids].astype(np.float32) - for li in range(config.n_layers): - x_cpu, cpu_intermediates = cpu_block( - x_cpu, weights.layers[li], rope_lut_f32, config - ) - cpu_k = ( - cpu_intermediates["k_roped"] - .astype(np.float32) - .reshape(seq_len, n_kv_heads, head_dim) - .transpose(1, 0, 2) - ) - cpu_v = ( - cpu_intermediates["v"] - .astype(np.float32) - .reshape(seq_len, n_kv_heads, head_dim) - .transpose(1, 0, 2) - ) - npu_k = k_cache[li, :, :seq_len, :].astype(np.float32) - npu_v = v_cache[li, :, :seq_len, :].astype(np.float32) - - k_corr = np.corrcoef(npu_k.flatten(), cpu_k.flatten())[0, 1] - v_corr = np.corrcoef(npu_v.flatten(), cpu_v.flatten())[0, 1] - k_maxerr = np.max(np.abs(npu_k - cpu_k)) - v_maxerr = np.max(np.abs(npu_v - cpu_v)) - k_meanerr = np.mean(np.abs(npu_k - cpu_k)) - v_meanerr = np.mean(np.abs(npu_v - cpu_v)) - - k_status = "OK" if k_corr > 0.99 else "WARN" - v_status = "OK" if v_corr > 0.99 else "WARN" - print( - f" Layer {li:2d} K_cache: [{k_status}] corr={k_corr:.6f}, " - f"max_err={k_maxerr:.4f}, mean_err={k_meanerr:.4f}" - ) - print( - f" Layer {li:2d} V_cache: [{v_status}] corr={v_corr:.6f}, " - f"max_err={v_maxerr:.4f}, mean_err={v_meanerr:.4f}" - ) - # Compare logits - x_cpu_normed = rms_norm(x_cpu, weights.final_norm.astype(np.float32)) - cpu_logits = x_cpu_normed @ weights.lm_head.astype(np.float32).T - cpu_pred = int(np.argmax(cpu_logits[pred_pos])) - logits_f32_row = logits_row.astype(np.float32) - logit_corr = np.corrcoef(logits_f32_row, cpu_logits[pred_pos])[0, 1] - logit_maxerr = np.max(np.abs(logits_f32_row - cpu_logits[pred_pos])) - logit_meanerr = np.mean(np.abs(logits_f32_row - cpu_logits[pred_pos])) - print( - f"\n Logits (pos {pred_pos}): corr={logit_corr:.6f}, " - f"max_err={logit_maxerr:.4f}, mean_err={logit_meanerr:.4f}" +# --------------------------------------------------------------------------- +# Single decode step (one transformer block traversal + LM head) +# --------------------------------------------------------------------------- +# +# Extracted from generate()'s decode loop so the verify subsystem can call +# the exact same code path production uses, instead of reimplementing the +# loop body in NpuRunner. Pure compute — no print / timing / streaming +# state. Caller is responsible for KV-cache positioning (current_pos), for +# feeding next_token's embedding back as x_decode_bf16 on the next step, +# and for any per-token bookkeeping (timing, EOS check, streaming). + + +def run_npu_decode_step( + x_decode_bf16, + weights, + config, + decode_cache, + rope_lut_bf16, + k_cache, + v_cache, + current_pos, +): + """Run one NPU decode step: 16 transformer blocks + final RMSNorm + LM head. + + Args: + x_decode_bf16: (emb_dim,) bfloat16 — input embedding for this step. + weights, config, decode_cache, rope_lut_bf16: passed through to + run_decode_block + the LM-head GEMV. + k_cache, v_cache: shape (n_layers, n_kv_heads, max_seq, head_dim). + run_decode_block writes into [layer_idx, :, current_pos, :]. + current_pos: position to write the new K/V at (and to read prior + K/V from for attention). + + Returns: + next_token: int — argmax of the LM-head logits. + logits: (vocab_size,) f32 — raw LM-head logits (production + discards with `_`; verify reads for top-k extraction). + """ + from llama32_1b_cpu_helpers import rms_norm + + vocab_size = weights.lm_head.shape[0] + + # 16 transformer blocks on NPU. + x = x_decode_bf16.copy() + for layer_idx in range(config.n_layers): + t0 = decode_cache.profiler.start_layer() + x = run_decode_block( + x, + weights.layers[layer_idx], + decode_cache, + config, + k_cache[layer_idx], + v_cache[layer_idx], + current_pos, + rope_lut_bf16, ) - print(f" NPU top-1: {prefill_token} ({tokenizer.decode([prefill_token])})") - print(f" CPU top-1: {cpu_pred} ({tokenizer.decode([cpu_pred])})") - print(f" Match: {'YES' if prefill_token == cpu_pred else 'NO'}") + decode_cache.profiler.end_layer(layer_idx, t0) - return prefill_token, k_cache, v_cache, prompt_len + # Final RMSNorm (CPU, single row — cheap). + with decode_cache.profiler.time_cpu("final_rms_norm"): + x_normed = rms_norm( + x.astype(np.float32).reshape(1, config.emb_dim), + weights.final_norm.astype(np.float32), + ) + + # NPU LM Head: 8-partition GEMV, single XRT call. + x_lm = x_normed.flatten().astype(bfloat16) + lm_inputs = [x_lm] + lm_output_indices = [] + for p in range(_LM_N_PARTITIONS): + lm_inputs.append(weights._lm_weight_parts_gemv[p]) + lm_inputs.append(np.zeros(_LM_N_PART, dtype=bfloat16)) + lm_output_indices.append(2 + 2 * p) + lm_results = decode_cache.load_and_run( + "lm_head_gemv", + LM_GEMV_BACKEND, + *lm_inputs, + output_indices=lm_output_indices, + static_input_indices={1 + 2 * p for p in range(_LM_N_PARTITIONS)}, + intermediate_indices={2 + 2 * p for p in range(_LM_N_PARTITIONS)}, + ) + + # Assemble logits from 8 partitions. + logits = np.zeros(vocab_size, dtype=np.float32) + for p in range(_LM_N_PARTITIONS): + n_start = p * _LM_N_PART + n_end = min(n_start + _LM_N_PART, vocab_size) + logits[n_start:n_end] = lm_results[2 + 2 * p][: n_end - n_start].astype( + np.float32 + ) + next_token = int(np.argmax(logits)) + return next_token, logits # --------------------------------------------------------------------------- @@ -530,22 +563,29 @@ def generate( tokenizer, n_tokens=10, profile=False, - verify=False, cpu_attn=True, on_token=None, + ttft_start=None, ): """Run NPU prefill + NPU decode generation. Token 0 = from prefill, tokens 1+ = from decode. Both prefill and decode use NPU LM Head. - """ - from llama32_1b_reference import rms_norm + `ttft_start`, if provided, is the perf_counter() reading from the + caller before tokenization. The Time-To-First-Token (TTFT) message + measures from that point to when the first token is decoded — i.e. + tokenize + EOS-pad + NPU prefill + LM head. This matches the + standard vLLM/TGI/TRT-LLM TTFT definition (end-to-end submit → + first token). If not provided, TTFT is measured from the start + of NPU prefill only. + """ seq_len = len(prompt_tokens) - emb_dim = config.emb_dim max_seq = seq_len + n_tokens - vocab_size = weights.lm_head.shape[0] streaming = on_token is not None + ttft_includes_tokenize = ttft_start is not None + if ttft_start is None: + ttft_start = time.perf_counter() if not streaming: print(f"\n{'='*60}") @@ -553,7 +593,10 @@ def generate( print(f"{'='*60}\n") # --- Phase 1: NPU Prefill --- - prefill_token, k_cache, v_cache, prompt_len = run_npu_prefill( + # logits_row is unused in production; verify reads it via run_npu_prefill directly. + # quiet=True: the unified TTFT line below covers the user-visible timing; + # run_npu_prefill's own "NPU prefill done in X.XXs" would be redundant. + prefill_token, _logits_row, k_cache, v_cache, prompt_len = run_npu_prefill( prompt_tokens, weights, config, @@ -564,10 +607,21 @@ def generate( tokenizer=tokenizer, cpu_attn=cpu_attn, profile=profile, - verify=verify, - quiet=streaming, + quiet=True, ) + ttft = time.perf_counter() - ttft_start + if not streaming: + scope = ( + "tokenize + EOS-pad + NPU prefill + LM head" + if ttft_includes_tokenize + else "NPU prefill + LM head" + ) + print( + f"Time to first token (TTFT): {ttft:.2f}s ({scope}). " + f"First token: {prefill_token}" + ) + # --- Phase 2: NPU Decode --- generated_tokens = [prefill_token] # Token 0 = from prefill current_pos = prompt_len @@ -583,69 +637,31 @@ def generate( t_decode_start = time.time() for token_idx in range(n_tokens): - t_token_start = time.perf_counter() - - # Run 16 transformer blocks on NPU - x = x_decode.copy() - for layer_idx in range(config.n_layers): - x = run_decode_block( - x, - weights.layers[layer_idx], - decode_cache, - config, - k_cache[layer_idx], - v_cache[layer_idx], - current_pos, - rope_lut_bf16, - ) - - # Final RMSNorm (CPU) - x_normed = rms_norm( - x.astype(np.float32).reshape(1, emb_dim), - weights.final_norm.astype(np.float32), - ) - - # LM Head (NPU -- 8-partition GEMV, single XRT call) - x_lm = x_normed.flatten().astype(bfloat16) - lm_inputs = [x_lm] - lm_output_indices = [] - for p in range(_LM_N_PARTITIONS): - lm_inputs.append(weights._lm_weight_parts_gemv[p]) - lm_inputs.append(np.zeros(_LM_N_PART, dtype=bfloat16)) - lm_output_indices.append(2 + 2 * p) - lm_results = decode_cache.load_and_run( - "lm_head_gemv", - LM_GEMV_BACKEND, - *lm_inputs, - output_indices=lm_output_indices, - static_input_indices={1 + 2 * p for p in range(_LM_N_PARTITIONS)}, - intermediate_indices={2 + 2 * p for p in range(_LM_N_PARTITIONS)}, + # One decode step (16 transformer blocks + final RMSNorm + LM head). + # Verify subsystem calls the same function — keeps "what we test" and + # "what we deploy" identical. Per-layer / per-call timings are + # recorded automatically inside cache.load_and_run when the + # decode_cache's Profiler is enabled (--profile). + next_token, _logits = run_npu_decode_step( + x_decode, + weights, + config, + decode_cache, + rope_lut_bf16, + k_cache, + v_cache, + current_pos, ) - # Assemble logits from 8 partitions - logits = np.zeros((1, vocab_size), dtype=np.float32) - for p in range(_LM_N_PARTITIONS): - n_start = p * _LM_N_PART - n_end = min(n_start + _LM_N_PART, vocab_size) - logits[0, n_start:n_end] = lm_results[2 + 2 * p][: n_end - n_start].astype( - np.float32 - ) - next_token = int(np.argmax(logits[0])) - - t_token = time.perf_counter() - t_token_start - generated_tokens.append(next_token) current_pos += 1 - x_decode = weights.embed_table[next_token].astype(bfloat16) + # Embed lookup for next iteration's input (CPU). + with decode_cache.profiler.time_cpu("embed_lookup"): + x_decode = weights.embed_table[next_token].astype(bfloat16) if streaming: on_token(next_token, _delta_text(tokenizer, generated_tokens, stream_state)) - if profile: - print( - f" Token {token_idx + 1}: id={next_token}, time={t_token*1000:.0f}ms" - ) - # Stop on EOS or EOT (instruct model emits <|eot_id|> = 128009) if next_token in (tokenizer.eos_token_id, 128009): break @@ -658,9 +674,185 @@ def generate( print(f"Tokens/second: {n_generated / t_decode:.2f}") print(f"Time/token: {t_decode / n_generated * 1000:.0f}ms") + # Fine-grained profiling report. Each Profiler is a noop unless + # build_session enabled it for --profile (production path is identical + # to make run; verify path also leaves these disabled). + if prefill_cache.profiler.enabled or decode_cache.profiler.enabled: + _print_dataflow_summary( + prefill_cache, decode_cache, config.n_layers, n_generated + ) + if prefill_cache.profiler.enabled: + print(f"\n{'='*60}\nPREFILL — detail tables") + prefill_cache.profiler.report() + if decode_cache.profiler.enabled: + print(f"\n{'='*60}\nDECODE ({n_generated} tokens) — detail tables") + decode_cache.profiler.report() + return generated_tokens +def _avg(times): + return sum(times) / len(times) if times else 0.0 + + +def _print_dataflow_summary(prefill_cache, decode_cache, n_layers, n_decode_tokens): + """Architecture-aware dataflow-ordered summary that mirrors the SVG in + docs/PROFILE.html. Generic detail tables (Per-Layer / NPU XRT / CPU Op + / Fine-Grained) print after this from each Profiler.report().""" + pp = prefill_cache.profiler + dp = decode_cache.profiler + + # Convert kernel_times / cpu_times entries to ms averages. + def k_avg(prof, name): + ts = prof.kernel_times.get(name, []) + return _avg(ts) * 1000.0 + + def c_avg(prof, name): + ts = prof.cpu_times.get(name, []) + return _avg(ts) * 1000.0 + + def k_count(prof, name): + return len(prof.kernel_times.get(name, [])) + + def c_count(prof, name): + return len(prof.cpu_times.get(name, [])) + + print(f"\n{'='*68}") + print("END-TO-END DATAFLOW (per make profile, dataflow order)") + print(f"{'='*68}") + + # Preprocessing reminder (one-time setup, not per-query). + prep_s = getattr(pp, "preprocessing_s", None) + if prep_s is not None: + print( + f"\n Preprocessing (one-time, prepare_runtime): {prep_s:.1f} s" + f" ← not counted in per-query wall below" + ) + + # ---- PREFILL ---- + if pp.enabled: + print(f"\n--- PREFILL (per query, seq_len padded) ---") + rms_p = k_avg(pp, "rms_gemms_rope") + fa_p = k_avg(pp, "flash_attn") + offn_p = k_avg(pp, "o_ffn") + kv_extract = c_avg(pp, "kv_cache_extract") + layer_avg = ( + sum(t for _, t in pp.layer_times) * 1000.0 / n_layers + if pp.layer_times + else 0 + ) + layer_npu_cpu = rms_p + fa_p + offn_p + kv_extract + layer_sched = max(0.0, layer_avg - layer_npu_cpu) + tok = c_avg(pp, "tokenize") * c_count(pp, "tokenize") + pad = c_avg(pp, "eos_pad") * c_count(pp, "eos_pad") + embed = c_avg(pp, "embed_lookup") * c_count(pp, "embed_lookup") + final_n = c_avg(pp, "final_rms_norm") * c_count(pp, "final_rms_norm") + # LM head is recorded in decode_cache (production runs the prefill-end + # LM head through the same 8-partition ELF). + lm_total = sum(dp.kernel_times.get("lm_head_gemv", [])) * 1000.0 + n_lm = k_count(dp, "lm_head_gemv") + # Per-token tracking: out of N lm_head calls, 1 is the prefill end + # and N-1 are decode tokens. Approximate prefill LM head as the avg. + lm_prefill = lm_total / n_lm if n_lm else 0.0 + layer_total = layer_avg * n_layers + e2e = tok + pad + embed + layer_total + final_n + lm_prefill + + col = 38 # label column width + + def row(label, kind, ms, note=""): + print(f" {label:<{col}}{kind:<6}{ms:>8.2f} ms {note}") + + row("tokenize", "CPU", tok) + row("eos_pad", "CPU", pad) + row("embed_lookup", "CPU", embed) + print( + f" ┌─ Decoder block × {n_layers} (per layer) ─────────────────────────────┐" + ) + row(" rms_gemms_rope.elf", "NPU", rms_p) + row(" flash_attn.elf", "NPU", fa_p) + row(" o_ffn.elf", "NPU", offn_p) + row(" kv_cache_extract", "CPU", kv_extract) + row(" python/numpy scheduling", "—", layer_sched) + print(f" │ {'─'*52}") + print(f" │ {'per-layer wall':<{col-3}}{'':<6}{layer_avg:>8.2f} ms") + print(f" └──────────────────────────────────────────────────────────┘") + print( + f" {'× ' + str(n_layers) + ' layers':<{col}}{'':<6}{layer_total:>8.2f} ms" + ) + row("final_rms_norm", "CPU", final_n) + row("lm_head_gemv.elf", "NPU", lm_prefill) + print(f" {'─'*60}") + print(f" {'End-to-end (prefill, per query)':<{col}}{'':<6}{e2e:>8.2f} ms") + + # ---- DECODE ---- + if dp.enabled and n_decode_tokens > 0: + print(f"\n--- DECODE (avg per token, {n_decode_tokens} tokens) ---") + rms_d = k_avg(dp, "rms_gemv_rope") + ogf_d = k_avg(dp, "o_gemv_ffn") + dec_attn = c_avg(dp, "decode_attention_cpu") + embed_d = c_avg(dp, "embed_lookup") + final_d = c_avg(dp, "final_rms_norm") + lm_d = k_avg(dp, "lm_head_gemv") + layer_d = ( + sum(t for _, t in dp.layer_times) * 1000.0 / (n_layers * n_decode_tokens) + if dp.layer_times + else 0 + ) + layer_d_sub = rms_d + ogf_d + dec_attn + layer_d_sched = max(0.0, layer_d - layer_d_sub) + e2e_d = embed_d + layer_d * n_layers + final_d + lm_d + + col = 38 + + def row(label, kind, ms, note=""): + print(f" {label:<{col}}{kind:<6}{ms:>8.2f} ms {note}") + + row("embed_lookup", "CPU", embed_d) + print( + f" ┌─ Decoder block × {n_layers} (per layer, per token) ─────────────────┐" + ) + row(" rms_gemv_rope.elf", "NPU", rms_d) + row(" decode_attention_cpu", "CPU", dec_attn) + row(" o_gemv_ffn.elf", "NPU", ogf_d) + row(" python/numpy scheduling", "—", layer_d_sched) + print(f" │ {'─'*52}") + print(f" │ {'per-layer wall':<{col-3}}{'':<6}{layer_d:>8.2f} ms") + print(f" └──────────────────────────────────────────────────────────┘") + print( + f" {'× ' + str(n_layers) + ' layers':<{col}}{'':<6}{layer_d * n_layers:>8.2f} ms" + ) + row("final_rms_norm", "CPU", final_d) + row("lm_head_gemv.elf", "NPU", lm_d) + print(f" {'─'*60}") + print(f" {'End-to-end (per token)':<{col}}{'':<6}{e2e_d:>8.2f} ms") + + # Per-token trend: did wall time grow with token index? (decode CPU + # attention is O(current_pos), but with 2048-token prompt the slope + # is usually invisible for short generations.) + walls = dp.per_token_walls_ms(n_layers) + if len(walls) >= 3: + avg_w = sum(walls) / len(walls) + mn = min(walls) + mx = max(walls) + # Show first/middle/last samples for the slope. + first = walls[0] + mid = walls[len(walls) // 2] + last = walls[-1] + slope = last - first + slope_pct = (slope / first * 100.0) if first else 0 + print( + f"\n Per-token layer-loop wall trend (decode-attention CPU scales with KV cache size):" + ) + print( + f" token 1 = {first:6.2f} ms token {len(walls)//2 + 1:2d} = {mid:6.2f} ms " + f"token {len(walls):2d} = {last:6.2f} ms" + ) + print( + f" min = {mn:6.2f} ms max = {mx:6.2f} ms avg = {avg_w:6.2f} ms " + f"first→last drift = {slope:+.2f} ms ({slope_pct:+.1f}%)" + ) + + # --------------------------------------------------------------------------- # Session lifecycle and per-turn execution # --------------------------------------------------------------------------- @@ -674,8 +866,20 @@ def build_session(args) -> Session: config = LlamaConfig() seq_len = 2048 - prefill_cache = KernelCache("prefill_kernel_cache", verbose=args.verbose) - decode_cache = KernelCache("decode_kernel_cache", verbose=args.verbose) + # Each cache gets its own Profiler so the final report can separate + # prefill from decode phases. Profilers are enabled only under + # --profile; otherwise every record_* call is a noop (production + # path is identical to make run). + prefill_cache = KernelCache( + "prefill_kernel_cache", + verbose=args.verbose, + profiler=Profiler(enabled=args.profile), + ) + decode_cache = KernelCache( + "decode_kernel_cache", + verbose=args.verbose, + profiler=Profiler(enabled=args.profile), + ) if not args.run_only: print("Compiling prefill kernels...") @@ -690,22 +894,17 @@ def build_session(args) -> Session: prefill_cache.load_manifest() decode_cache.load_manifest() - if args.synthetic_weights: - print("\nUsing synthetic random weights (skipping HuggingFace download).") - weights = synthetic_weights(config) - tokenizer = _SyntheticTokenizer() - else: - model_id = ( - "meta-llama/Llama-3.2-1B-Instruct" - if args.model == "instruct" - else "meta-llama/Llama-3.2-1B" - ) - print(f"\nLoading weights ({model_id})...") - weights = load_weights(model_id) + model_id = ( + "meta-llama/Llama-3.2-1B-Instruct" + if args.model == "instruct" + else "meta-llama/Llama-3.2-1B" + ) + print(f"\nLoading weights ({model_id})...") + weights = load_weights(model_id) - from transformers import AutoTokenizer + from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) rope_lut_bf16 = generate_rope_lut( config=config, @@ -745,18 +944,25 @@ def run_once( *, n_tokens: int, profile: bool = False, - verify: bool = False, cpu_attn: bool = True, on_token: Optional[Callable[[int, str], None]] = None, ) -> tuple[list, int]: """Tokenize, pad to seq_len, and call generate(). Returns (generated_token_ids, prompt_len_actual).""" - tokens = _tokenize_prompt(session, prompt_text) + # Tokenize + EOS-pad are part of the per-query critical path (standard + # TTFT scope per vLLM/TGI/TRT-LLM), so we time them with the rest of + # prefill: ttft_start is captured BEFORE tokenize, then handed to + # generate(), which prints the unified "Time to first token (TTFT)" + # line covering tokenize + EOS-pad + NPU prefill + LM head. + ttft_start = time.perf_counter() + with session.prefill_cache.profiler.time_cpu("tokenize"): + tokens = _tokenize_prompt(session, prompt_text) prompt_len_actual = len(tokens) - if len(tokens) < session.seq_len: - tokens = tokens + [session.tokenizer.eos_token_id] * ( - session.seq_len - len(tokens) - ) + with session.prefill_cache.profiler.time_cpu("eos_pad"): + if len(tokens) < session.seq_len: + tokens = tokens + [session.tokenizer.eos_token_id] * ( + session.seq_len - len(tokens) + ) generated = generate( tokens, @@ -768,9 +974,9 @@ def run_once( tokenizer=session.tokenizer, n_tokens=n_tokens, profile=profile, - verify=verify, cpu_attn=cpu_attn, on_token=on_token, + ttft_start=ttft_start, ) return generated, prompt_len_actual @@ -833,11 +1039,9 @@ def _stream_cb(_token_id: int, delta: str) -> None: session, prompt, n_tokens=args.n_tokens, - # profile/verify are forced to False by the --interactive - # mutex block in __main__; pass through as the single source - # of truth. + # profile is forced to False by the --interactive mutex + # block in __main__; pass through as the single source of truth. profile=args.profile, - verify=args.verify, cpu_attn=args.cpu_attn, on_token=_stream_cb, ) @@ -877,11 +1081,6 @@ def _stream_cb(_token_id: int, delta: str) -> None: action="store_true", help="Enable per-token timing instrumentation", ) - parser.add_argument( - "--verify", - action="store_true", - help="Compare against CPU F32 reference", - ) parser.add_argument( "--cpu-attn", action="store_true", @@ -905,17 +1104,8 @@ def _stream_cb(_token_id: int, delta: str) -> None: action="store_true", help="Drop into a REPL after runtime prep. Loops on prompts; each is independent.", ) - parser.add_argument( - "--synthetic-weights", - action="store_true", - help="Use deterministic random weights instead of HuggingFace weights " - "(no download / no auth). Intended for CI smoke + verify tests.", - ) args = parser.parse_args() - if args.synthetic_weights and args.interactive: - parser.error("--synthetic-weights cannot be combined with --interactive") - if args.interactive: if args.compile_only: parser.error("--interactive cannot be combined with --compile-only") @@ -932,44 +1122,17 @@ def _stream_cb(_token_id: int, delta: str) -> None: file=sys.stderr, ) args.profile = False - if args.verify: - print( - "WARNING: --verify is ignored in --interactive mode.", - file=sys.stderr, - ) - args.verify = False session = build_session(args) if args.interactive: repl_loop(session, args) - elif args.synthetic_weights: - # Bypass real tokenization: feed a deterministic token-id sequence - # straight into generate(). Output text is not meaningful — the value - # of this path is the --verify correlation against the CPU reference. - token_ids = ( - np.arange(session.seq_len, dtype=np.int64) % session.config.vocab_size - ).tolist() - generate( - token_ids, - session.weights, - session.config, - session.prefill_cache, - session.decode_cache, - session.rope_lut_bf16, - tokenizer=session.tokenizer, - n_tokens=args.n_tokens, - profile=args.profile, - verify=args.verify, - cpu_attn=args.cpu_attn, - ) else: generated, prompt_len_actual = run_once( session, args.prompt, n_tokens=args.n_tokens, profile=args.profile, - verify=args.verify, cpu_attn=args.cpu_attn, ) _print_one_shot_output(session, args.prompt, generated, prompt_len_actual) diff --git a/programming_examples/llama32_1b/llama32_1b_prefill.py b/programming_examples/llama32_1b/llama32_1b_prefill.py index 53d4641d9..db748e1e8 100644 --- a/programming_examples/llama32_1b/llama32_1b_prefill.py +++ b/programming_examples/llama32_1b/llama32_1b_prefill.py @@ -41,12 +41,7 @@ sys.path.insert(0, _PROG_EXAMPLES) from llama32_1b_weights import LlamaConfig, load_weights, generate_rope_lut -from llama32_1b_reference import ( - rms_norm as rms_norm_ref, - apply_rope as apply_rope_ref, - attention_reference, - ffn_full_reference, -) +from llama32_1b_cpu_helpers import attention_reference from kernel_builder.cache import KernelCache, Profiler from kernel_builder.backend_presets import ( SIMPLE_BACKEND, @@ -167,16 +162,13 @@ def compile_all_kernels(cache, config, seq_len, cpu_attn=True): # --------------------------------------------------------------------------- -def _attn_backend_kwargs(head_dim): - lkp = head_dim - enable_shared_buffers = lkp == head_dim - return { - "omit_while_true_loop": not enable_shared_buffers, - "omit_pingpong": "all", - "runtime_loop_tiling_sizes": [1, 1], - "output_format": "elf", - "instance_name": "attention_bf16", - } +_ATTN_BACKEND_KWARGS = { + "omit_while_true_loop": False, + "omit_pingpong": "all", + "runtime_loop_tiling_sizes": [1, 1], + "output_format": "elf", + "instance_name": "attention_bf16", +} def run_transformer_block( @@ -186,7 +178,6 @@ def run_transformer_block( config, cache, layer_idx=0, - verify=False, cpu_attn=True, verbose=False, ): @@ -199,7 +190,6 @@ def run_transformer_block( config: LlamaConfig cache: KernelCache instance (kernels must be pre-compiled) layer_idx: Layer index for logging - verify: If True, compare each intermediate against CPU reference cpu_attn: If True, use CPU attention fallback instead of NPU kernel verbose: If True, print per-step progress @@ -221,23 +211,6 @@ def run_transformer_block( _arg_cache = getattr(run_transformer_block, "_arg_cache", {}) run_transformer_block._arg_cache = _arg_cache - def _compare(name, npu_result, cpu_ref=None): - """Compare NPU result against a per-step CPU reference.""" - intermediates[name] = npu_result - if cpu_ref is not None: - npu_f32 = npu_result.astype(np.float32).flatten() - ref_f32 = np.asarray(cpu_ref, dtype=np.float32).flatten() - if npu_f32.shape == ref_f32.shape: - abs_err = np.max(np.abs(npu_f32 - ref_f32)) - denom = np.maximum(np.abs(ref_f32), 1e-6) - rel_err = np.mean(np.abs(npu_f32 - ref_f32) / denom) - corr = np.corrcoef(npu_f32, ref_f32)[0, 1] if len(npu_f32) > 1 else 1.0 - status = "OK" if corr > 0.99 else "WARN" - print( - f" [{status}] {name}: max_err={abs_err:.4f}, " - f"mean_rel={rel_err:.4f}, corr={corr:.6f}" - ) - if verbose: print(f" Layer {layer_idx}: Running transformer block...") @@ -281,28 +254,11 @@ def _compare(name, npu_result, cpu_ref=None): v = results[8].reshape(seq_len, kv_dim) q_roped = results[11].reshape(seq_len, n_heads * head_dim) k_roped = results[12].reshape(seq_len, n_kv_heads * head_dim) - # Store v and k_roped — needed by caller for KV cache extraction + # Store per-probe intermediates — used by KV-cache extraction (v, k_roped) + # AND by verify/runners/npu_runner.py to capture per-probe NPU outputs. intermediates["v"] = v intermediates["k_roped"] = k_roped - if verify: - normed_ref = rms_norm_ref(x_bf16.astype(np.float32), layer_weights.attn_norm) - ref_v = normed_ref @ np.asarray(layer_weights.wv, dtype=np.float32) - _compare("v", v, ref_v) - ref_q = normed_ref @ np.asarray(layer_weights.wq, dtype=np.float32) - ref_k = normed_ref @ np.asarray(layer_weights.wk, dtype=np.float32) - lut_f32 = rope_lut_bf16[:seq_len].astype(np.float32) - q_heads_f32 = ref_q.reshape(seq_len, n_heads, head_dim) - ref_q_roped = np.empty_like(q_heads_f32) - for h in range(n_heads): - ref_q_roped[:, h, :] = apply_rope_ref(q_heads_f32[:, h, :], lut_f32) - _compare("q_roped", q_roped, ref_q_roped.reshape(seq_len, n_heads * head_dim)) - k_heads_f32 = ref_k.reshape(seq_len, n_kv_heads, head_dim) - ref_k_roped = np.empty_like(k_heads_f32) - for h in range(n_kv_heads): - ref_k_roped[:, h, :] = apply_rope_ref(k_heads_f32[:, h, :], lut_f32) - _compare( - "k_roped", k_roped, ref_k_roped.reshape(seq_len, n_kv_heads * head_dim) - ) + intermediates["q_roped"] = q_roped # 7. Flash Attention GQA if cpu_attn: @@ -310,13 +266,14 @@ def _compare(name, npu_result, cpu_ref=None): print( f" Step 7: Attention GQA [CPU fallback] ({n_heads}Q/{n_kv_heads}KV heads)" ) - attn_out = attention_reference( - q_roped.astype(np.float32), - k_roped.astype(np.float32), - v.astype(np.float32), - n_heads, - n_kv_heads, - ).astype(bfloat16) + with cache.profiler.time_cpu("prefill_cpu_attention"): + attn_out = attention_reference( + q_roped.astype(np.float32), + k_roped.astype(np.float32), + v.astype(np.float32), + n_heads, + n_kv_heads, + ).astype(bfloat16) else: if verbose: print( @@ -326,16 +283,16 @@ def _compare(name, npu_result, cpu_ref=None): k_attn = np.ascontiguousarray(k_roped) v_attn = np.ascontiguousarray(v) attn_output = np.zeros((seq_len, n_heads * head_dim), dtype=bfloat16) - attn_bk = _attn_backend_kwargs(head_dim) results = cache.load_and_run( "flash_attn", - attn_bk, + _ATTN_BACKEND_KWARGS, q_attn, k_attn, v_attn, attn_output, ) attn_out = results[-1].reshape(seq_len, n_heads * head_dim) + intermediates["attn_out"] = attn_out # 8-15. O GEMM + Residual Add + FFN [8-launch multi-launch ELF] if verbose: @@ -386,19 +343,7 @@ def _compare(name, npu_result, cpu_ref=None): bo_key=_offn_key, ) output_bf16 = results[14].reshape(seq_len, emb_dim) - if verify: - proj_ref = attn_out.astype(np.float32) @ np.asarray( - layer_weights.wo, dtype=np.float32 - ) - res1_ref = x_bf16.astype(np.float32) + proj_ref - ref = ffn_full_reference( - res1_ref.astype(bfloat16), - layer_weights.ffn_norm, - layer_weights.w_gate, - layer_weights.w_up, - layer_weights.w_down, - ).reshape(seq_len, emb_dim) - _compare("output", output_bf16, ref) + intermediates["ffn_out"] = output_bf16 return output_bf16, intermediates diff --git a/programming_examples/llama32_1b/llama32_1b_reference.py b/programming_examples/llama32_1b/llama32_1b_reference.py deleted file mode 100644 index 1834b91f8..000000000 --- a/programming_examples/llama32_1b/llama32_1b_reference.py +++ /dev/null @@ -1,480 +0,0 @@ -# Copyright (C) 2026, Advanced Micro Devices, Inc. -# SPDX-License-Identifier: MIT - -"""CPU reference implementation of LLAMA-3.2-1B forward pass. - -Pure NumPy in F32 for numerical verification against NPU results. -All intermediate computations are done in F32 (weights are cast from BF16 -at use time) to provide a high-accuracy reference. - -LLAMA-3.2-1B config: - 16 layers, emb_dim=2048, n_heads=32, head_dim=64, n_kv_heads=8, - hidden_dim=8192, vocab_size=128256, BF16, rope_base=500000 -""" - -import argparse -import numpy as np -from ml_dtypes import bfloat16 - -from llama32_1b_weights import ( - LlamaConfig, - LayerWeights, - LlamaWeights, - load_weights, - generate_rope_lut, -) - - -def rms_norm(x, weight, eps=1e-5): - """RMS normalization: x / sqrt(mean(x^2) + eps) * weight. - - Args: - x: (M, N) input array in F32. - weight: (N,) learned scale parameter. - eps: Small constant for numerical stability. - - Returns: - (M, N) normalized and scaled array in F32. - """ - x = np.asarray(x, dtype=np.float32) - weight = np.asarray(weight, dtype=np.float32) - # Compute RMS per row - rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) - return (x / rms) * weight - - -def apply_rope(x, lut): - """Apply Rotary Position Embedding using a precomputed LUT. - - Uses half-split convention (matching HuggingFace Llama): - pairs (x[i], x[i + dim//2]) with rotation angle theta_i. - - LUT layout: [cos_0, ..., cos_{half-1}, sin_0, ..., sin_{half-1}] - - Args: - x: (seq_len, head_dim) input for one head. - lut: (seq_len, head_dim) with concatenated [cos..., sin...]. - - Returns: - (seq_len, head_dim) with RoPE applied. - """ - x = np.asarray(x, dtype=np.float32) - lut = np.asarray(lut, dtype=np.float32) - dim = x.shape[-1] - half = dim // 2 - - cos_vals = lut[:, :half] - sin_vals = lut[:, half:] - - x1 = x[:, :half] - x2 = x[:, half:] - - out = np.empty_like(x) - out[:, :half] = x1 * cos_vals - x2 * sin_vals - out[:, half:] = x1 * sin_vals + x2 * cos_vals - return out - - -def silu(x): - """SiLU activation: x * sigmoid(x). - - Args: - x: Input array (any shape) in F32. - - Returns: - SiLU-activated array with the same shape. - """ - x = np.asarray(x, dtype=np.float32) - return x * (1.0 / (1.0 + np.exp(-x))) - - -def swiglu(gate, up): - """SwiGLU gating: SiLU(gate) * up. - - Args: - gate: Gate input array in F32. - up: Up-projection input array in F32. - - Returns: - Element-wise SiLU(gate) * up. - """ - return silu(gate) * np.asarray(up, dtype=np.float32) - - -def ffn_full_reference(x, ffn_norm_weight, w_gate, w_up, w_down, eps=1e-5): - """CPU F32 reference for the full FFN block: - RMSNorm -> Gate -> Up -> SwiGLU -> Down -> Residual Add. - - Args: - x: (seq_len, emb_dim) input (residual state) - ffn_norm_weight: (emb_dim,) RMSNorm weight - w_gate: (emb_dim, hidden_dim) gate projection weight - w_up: (emb_dim, hidden_dim) up projection weight - w_down: (hidden_dim, emb_dim) down projection weight - eps: RMSNorm epsilon - - Returns: - (seq_len, emb_dim) bfloat16: x + down_proj(SwiGLU(gate, up)) - """ - x_f32 = x.astype(np.float32) - normed = rms_norm(x_f32, ffn_norm_weight, eps) - gate = normed @ w_gate.astype(np.float32) - up = normed @ w_up.astype(np.float32) - down = swiglu(gate, up) @ w_down.astype(np.float32) - return (x_f32 + down).astype(bfloat16) - - -def softmax(x, axis=-1): - """Numerically stable softmax. - - Args: - x: Input array in F32. - axis: Axis along which to compute softmax. - - Returns: - Softmax probabilities with the same shape as x. - """ - x = np.asarray(x, dtype=np.float32) - x_max = np.max(x, axis=axis, keepdims=True) - exp_x = np.exp(x - x_max) - return exp_x / np.sum(exp_x, axis=axis, keepdims=True) - - -def attention_reference(q, k, v, n_heads, n_kv_heads): - """Multi-head attention with Grouped Query Attention (GQA). - - Args: - q: (seq_len, n_heads * head_dim) -- already projected and RoPE'd. - k: (seq_len, n_kv_heads * head_dim) -- already projected and RoPE'd. - v: (seq_len, n_kv_heads * head_dim) -- already projected. - n_heads: Number of query heads. - n_kv_heads: Number of key/value heads (for GQA). - - Returns: - (seq_len, n_heads * head_dim) attention output. - """ - q = np.asarray(q, dtype=np.float32) - k = np.asarray(k, dtype=np.float32) - v = np.asarray(v, dtype=np.float32) - - seq_len = q.shape[0] - head_dim = q.shape[1] // n_heads - group_size = n_heads // n_kv_heads - - # Reshape to per-head views - # q: (seq_len, n_heads, head_dim) -> (n_heads, seq_len, head_dim) - q = q.reshape(seq_len, n_heads, head_dim).transpose(1, 0, 2) - # k: (seq_len, n_kv_heads, head_dim) -> (n_kv_heads, seq_len, head_dim) - k = k.reshape(seq_len, n_kv_heads, head_dim).transpose(1, 0, 2) - # v: (seq_len, n_kv_heads, head_dim) -> (n_kv_heads, seq_len, head_dim) - v = v.reshape(seq_len, n_kv_heads, head_dim).transpose(1, 0, 2) - - scale = 1.0 / np.sqrt(head_dim) - - # Causal mask: mask[i][j] = 0 if j <= i, else -inf - causal_mask = np.triu(np.full((seq_len, seq_len), -np.inf, dtype=np.float32), k=1) - - # Compute attention for each query head - out_heads = np.empty((n_heads, seq_len, head_dim), dtype=np.float32) - for h in range(n_heads): - kv_idx = h // group_size - # scores: (seq_len, seq_len) - scores = q[h] @ k[kv_idx].T * scale - scores = scores + causal_mask - probs = softmax(scores, axis=-1) - out_heads[h] = probs @ v[kv_idx] - - # Reshape back: (n_heads, seq_len, head_dim) -> (seq_len, n_heads * head_dim) - out = out_heads.transpose(1, 0, 2).reshape(seq_len, n_heads * head_dim) - return out - - -def transformer_block(x, layer_weights, rope_lut, config): - """Single transformer block with attention and FFN. - - Args: - x: (seq_len, emb_dim) input in F32. - layer_weights: LayerWeights for this layer. - rope_lut: (seq_len, head_dim) RoPE lookup table. - config: LlamaConfig with model hyperparameters. - - Returns: - (output, intermediates) where output is (seq_len, emb_dim) in F32 - and intermediates is a dict mapping step names to arrays. - """ - x = np.asarray(x, dtype=np.float32) - intermediates = {} - seq_len = x.shape[0] - n_heads = config.n_heads - n_kv_heads = config.n_kv_heads - head_dim = config.head_dim - - # --- Self-attention --- - - # 1. Pre-attention RMS norm - normed = rms_norm(x, layer_weights.attn_norm) - intermediates["attn_norm"] = normed - - # 2-4. QKV projections - wq = np.asarray(layer_weights.wq, dtype=np.float32) - wk = np.asarray(layer_weights.wk, dtype=np.float32) - wv = np.asarray(layer_weights.wv, dtype=np.float32) - q = normed @ wq # (seq_len, n_heads * head_dim) = (seq_len, 2048) - k = normed @ wk # (seq_len, n_kv_heads * head_dim) = (seq_len, 512) - v = normed @ wv # (seq_len, n_kv_heads * head_dim) = (seq_len, 512) - intermediates["q"] = q - intermediates["k"] = k - intermediates["v"] = v - - # 5. Apply RoPE to Q (per-head) - # Reshape Q: (seq_len, n_heads, head_dim) -> process each head independently - q_heads = q.reshape(seq_len, n_heads, head_dim) - q_roped_heads = np.empty_like(q_heads) - for h in range(n_heads): - q_roped_heads[:, h, :] = apply_rope( - q_heads[:, h, :].reshape(seq_len, head_dim), rope_lut[:seq_len] - ) - q_roped = q_roped_heads.reshape(seq_len, n_heads * head_dim) - intermediates["q_roped"] = q_roped - - # 6. Apply RoPE to K (per-head) - k_heads = k.reshape(seq_len, n_kv_heads, head_dim) - k_roped_heads = np.empty_like(k_heads) - for h in range(n_kv_heads): - k_roped_heads[:, h, :] = apply_rope( - k_heads[:, h, :].reshape(seq_len, head_dim), rope_lut[:seq_len] - ) - k_roped = k_roped_heads.reshape(seq_len, n_kv_heads * head_dim) - intermediates["k_roped"] = k_roped - - # 7. Attention - attn_out = attention_reference(q_roped, k_roped, v, n_heads, n_kv_heads) - intermediates["attn_out"] = attn_out - - # 8. Output projection - wo = np.asarray(layer_weights.wo, dtype=np.float32) - proj = attn_out @ wo # (seq_len, emb_dim) - intermediates["proj"] = proj - - # 9. Residual connection - res1 = x + proj - intermediates["res1"] = res1 - - # --- Feed-forward network --- - - # 10. Pre-FFN RMS norm - normed2 = rms_norm(res1, layer_weights.ffn_norm) - intermediates["ffn_norm"] = normed2 - - # 11-12. Gate and Up projections - w_gate = np.asarray(layer_weights.w_gate, dtype=np.float32) - w_up = np.asarray(layer_weights.w_up, dtype=np.float32) - gate = normed2 @ w_gate # (seq_len, hidden_dim) = (seq_len, 8192) - up = normed2 @ w_up # (seq_len, hidden_dim) = (seq_len, 8192) - intermediates["gate"] = gate - intermediates["up"] = up - - # 13. SwiGLU activation - swiglu_out = swiglu(gate, up) - intermediates["swiglu"] = swiglu_out - - # 14. Down projection - w_down = np.asarray(layer_weights.w_down, dtype=np.float32) - down = swiglu_out @ w_down # (seq_len, emb_dim) = (seq_len, 2048) - intermediates["down"] = down - - # 15. Residual connection - output = res1 + down - intermediates["output"] = output - - return output, intermediates - - -def forward(token_ids, weights, config, rope_lut=None): - """Full LLAMA-3.2-1B forward pass. - - Args: - token_ids: (seq_len,) integer array of token IDs. - weights: LlamaWeights containing all model parameters. - config: LlamaConfig with model hyperparameters. - rope_lut: Optional precomputed (seq_len, head_dim) RoPE LUT. - If None, one will be generated using generate_rope_lut. - - Returns: - logits: (seq_len, vocab_size) in F32. - """ - seq_len = len(token_ids) - - # Generate RoPE LUT if not provided - if rope_lut is None: - rope_lut = generate_rope_lut(config=config, seq_len=seq_len) - rope_lut = np.asarray(rope_lut, dtype=np.float32) - - # 1. Token embedding (CPU lookup) - embed_table = np.asarray(weights.embed_table, dtype=np.float32) - x = embed_table[token_ids] # (seq_len, emb_dim) - - # 2. Transformer blocks - for i in range(config.n_layers): - x, _ = transformer_block(x, weights.layers[i], rope_lut, config) - - # 3. Final RMS norm - x = rms_norm(x, weights.final_norm) - - # 4. Language model head (CPU GEMM) - lm_head = np.asarray(weights.lm_head, dtype=np.float32) - logits = x @ lm_head.T # (seq_len, vocab_size) - - return logits - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="CPU reference forward pass for LLAMA-3.2-1B" - ) - parser.add_argument( - "--model", - type=str, - default="meta-llama/Llama-3.2-1B", - help="HuggingFace model name or local path (default: meta-llama/Llama-3.2-1B)", - ) - parser.add_argument( - "--prompt", - type=str, - default="The capital of France is", - help="Input prompt (default: 'The capital of France is')", - ) - parser.add_argument( - "--seq-len", - type=int, - default=128, - help="Sequence length to pad/truncate to (default: 128)", - ) - parser.add_argument( - "--verify", - action="store_true", - help="Compare output against HuggingFace transformers reference", - ) - args = parser.parse_args() - - # Load weights - config = LlamaConfig() - print(f"Loading weights from {args.model}...") - weights = load_weights(args.model, config=config) - print(f" Config: {config}") - print( - f" Layers: {config.n_layers}, emb_dim: {config.emb_dim}, " - f"n_heads: {config.n_heads}, n_kv_heads: {config.n_kv_heads}, " - f"hidden_dim: {config.hidden_dim}, vocab_size: {config.vocab_size}" - ) - - # Tokenize - from transformers import AutoTokenizer - - tokenizer = AutoTokenizer.from_pretrained(args.model) - token_ids = tokenizer.encode(args.prompt) - print(f"\nPrompt: '{args.prompt}'") - print(f"Token IDs ({len(token_ids)} tokens): {token_ids}") - - # Pad or truncate to seq_len - if len(token_ids) > args.seq_len: - token_ids = token_ids[: args.seq_len] - print(f"Truncated to {args.seq_len} tokens") - elif len(token_ids) < args.seq_len: - # Pad with EOS token (or 0 if no EOS) - pad_token = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0 - original_len = len(token_ids) - token_ids = token_ids + [pad_token] * (args.seq_len - len(token_ids)) - print( - f"Padded from {original_len} to {args.seq_len} tokens " - f"(pad_token={pad_token})" - ) - - token_ids = np.array(token_ids, dtype=np.int64) - - # Run forward pass - print(f"\nRunning forward pass (seq_len={args.seq_len})...") - logits = forward(token_ids, weights, config) - print(f"Output logits shape: {logits.shape}") - - # Get the prediction at the last real token position - # (the position just before padding starts, or the last position if no padding) - prompt_len = len(tokenizer.encode(args.prompt)) - pred_pos = min(prompt_len - 1, args.seq_len - 1) - - # Top-5 predicted next tokens - next_token_logits = logits[pred_pos] - top5_indices = np.argsort(next_token_logits)[-5:][::-1] - top5_probs = softmax(next_token_logits) - - print(f"\nTop-5 predicted next tokens (position {pred_pos}):") - for rank, idx in enumerate(top5_indices): - token_str = tokenizer.decode([idx]) - prob = top5_probs[idx] - print( - f" {rank + 1}. '{token_str}' (id={idx}, logit={next_token_logits[idx]:.4f}, " - f"prob={prob:.4f})" - ) - - # Optional: verify against HuggingFace transformers - if args.verify: - print("\n--- Verification against HuggingFace transformers ---") - try: - import torch - from transformers import AutoModelForCausalLM - - print("Loading HuggingFace model...") - hf_model = AutoModelForCausalLM.from_pretrained( - args.model, torch_dtype=torch.float32 - ) - hf_model.eval() - - with torch.no_grad(): - input_ids = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0) - hf_output = hf_model(input_ids) - hf_logits = hf_output.logits[0].numpy() # (seq_len, vocab_size) - - print(f"HF logits shape: {hf_logits.shape}") - print(f"Our logits shape: {logits.shape}") - - # Compare at the prediction position - our_next = logits[pred_pos] - hf_next = hf_logits[pred_pos] - - # Absolute and relative error - abs_diff = np.abs(our_next - hf_next) - max_abs_err = np.max(abs_diff) - mean_abs_err = np.mean(abs_diff) - - # Relative error (avoid division by zero) - denom = np.maximum(np.abs(hf_next), 1e-8) - rel_diff = abs_diff / denom - max_rel_err = np.max(rel_diff) - mean_rel_err = np.mean(rel_diff) - - print(f"\nError at position {pred_pos}:") - print(f" Max absolute error: {max_abs_err:.6f}") - print(f" Mean absolute error: {mean_abs_err:.6f}") - print(f" Max relative error: {max_rel_err:.6f}") - print(f" Mean relative error: {mean_rel_err:.6f}") - - # Check if top-1 predictions match - our_top1 = np.argmax(our_next) - hf_top1 = np.argmax(hf_next) - match = our_top1 == hf_top1 - print(f"\nTop-1 prediction match: {'YES' if match else 'NO'}") - print(f" Ours: '{tokenizer.decode([our_top1])}' (id={our_top1})") - print(f" HF: '{tokenizer.decode([hf_top1])}' (id={hf_top1})") - - # Overall logits correlation - correlation = np.corrcoef(our_next, hf_next)[0, 1] - print(f" Logits correlation: {correlation:.8f}") - - if match and correlation > 0.999: - print("\nVERIFICATION PASSED") - else: - print("\nVERIFICATION FAILED") - - except ImportError as e: - print(f"Cannot verify: {e}") - print("Install torch and transformers: pip install torch transformers") diff --git a/programming_examples/llama32_1b/run_npu2_makefile_peano_synthetic_verify.lit b/programming_examples/llama32_1b/run_npu2_makefile_peano_synthetic_verify.lit deleted file mode 100644 index e85efda83..000000000 --- a/programming_examples/llama32_1b/run_npu2_makefile_peano_synthetic_verify.lit +++ /dev/null @@ -1,32 +0,0 @@ -// (c) Copyright 2026 Advanced Micro Devices, Inc. -// SPDX-License-Identifier: MIT -// -// REQUIRES: ryzen_ai_npu2, peano -// -// End-to-end LLAMA-3.2-1B prefill + 1 decode token with deterministic -// random weights (no HuggingFace download / no auth in CI). Compares the -// per-layer NPU output against a CPU F32 reference computed from the same -// synthetic weight tensors. We FileCheck the per-layer-internal -// correctness markers (q_roped / k_roped / final output) which are -// invariant to weight magnitude — the end-to-end K-cache drift after 16 -// layers is expected with unnormalized random weights and is not asserted -// here. -// -// RUN: mkdir -p test_synthetic_verify -// RUN: cd test_synthetic_verify -// RUN: make -f %S/Makefile clean PEANO_INSTALL_DIR=%PEANO_INSTALL_DIR -// RUN: make -f %S/Makefile compile PEANO_INSTALL_DIR=%PEANO_INSTALL_DIR -// RUN: make -f %S/Makefile verify WEIGHTS=synthetic N_TOKENS=1 PEANO_INSTALL_DIR=%PEANO_INSTALL_DIR | FileCheck %s -// -// Synthetic-weights banner. -// CHECK: Using synthetic random weights -// -// Per-layer kernel correctness — q_roped / k_roped / output all produced -// by the multi-launch ELFs and compared against the CPU F32 reference. -// CHECK: [OK] q_roped: {{.*}}corr=0.99 -// CHECK: [OK] k_roped: {{.*}}corr=0.99 -// CHECK: [OK] output: {{.*}}corr=0.99 -// -// Pipeline reaches end of prefill and emits at least one decode token. -// CHECK: NPU prefill done -// CHECK: Tokens/second diff --git a/programming_examples/llama32_1b/run_npu2_verify.lit b/programming_examples/llama32_1b/run_npu2_verify.lit new file mode 100644 index 000000000..dda8eba8d --- /dev/null +++ b/programming_examples/llama32_1b/run_npu2_verify.lit @@ -0,0 +1,17 @@ +// (c) Copyright 2026 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT +// +// LLAMA-3.2-1B verify gate: top-k token-level inclusion check, NPU vs HF bf16, +// 8 prompts × 32 greedy tokens, k=5. Exercises the full production prefill + +// decode path through the verify subsystem (verify/verify_runner.py). +// +// Skips cleanly when HF_TOKEN is unset (gated model downloads require it). +// +// REQUIRES: ryzen_ai_npu2, peano, hf_token +// +// RUN: mkdir -p test_peano_verify +// RUN: cd test_peano_verify +// RUN: make -f %S/Makefile clean +// RUN: make -f %S/Makefile compile PEANO_INSTALL_DIR=%PEANO_INSTALL_DIR +// RUN: make -f %S/Makefile verify PEANO_INSTALL_DIR=%PEANO_INSTALL_DIR | FileCheck %s +// CHECK: [verify] PASS diff --git a/programming_examples/llama32_1b/verify/.gitignore b/programming_examples/llama32_1b/verify/.gitignore new file mode 100644 index 000000000..d82687ff1 --- /dev/null +++ b/programming_examples/llama32_1b/verify/.gitignore @@ -0,0 +1,7 @@ +reports/ +__pycache__/ +*.pyc +# External-kernel objects spilled by compile_all_external_kernels into cwd +*.o +# Calibration backup file +thresholds.json.bak diff --git a/programming_examples/llama32_1b/verify/README.md b/programming_examples/llama32_1b/verify/README.md new file mode 100644 index 000000000..027c08a1c --- /dev/null +++ b/programming_examples/llama32_1b/verify/README.md @@ -0,0 +1,97 @@ +# Llama-3.2-1B verification + +Two ways to look at the production Llama-3.2-1B NPU2 inference pipeline, +both comparing against HuggingFace transformers in **bf16** (same dtype +as NPU — fair fight). Companion doc: `../docs/VERIFICATION.html`. + +Targets live in the parent Makefile (`programming_examples/llama32_1b/Makefile`): + +``` +cd programming_examples/llama32_1b + +make verify [MODEL=instruct|base] # ~4 min — top-k token-level correctness gate +make diagnosis [MODEL=...] [PROMPT="..."] # ~3 min — per-layer cosine, informational +make clean # rm build_*/ + verify/reports/ +``` + +## `make verify` — the correctness gate + +Top-k token-level inclusion check (mirrors vLLM's +`check_logprobs_close` in `tests/models/utils.py`). For each of 8 prompts: + +1. NPU and HF each greedy-decode 32 tokens, capturing top-5 token IDs per step. +2. Walk in lockstep. On the first step where chosen tokens differ, both + sides' chosen tokens must appear in the OTHER side's top-5; otherwise + FAIL. Stop walking after first divergence. +3. All 8 prompts must pass. `verify_runner.py` exits 1 on any FAIL, + exit 0 on PASS. + +This is the only correctness signal. The discrete top-k judgment is +robust to the bf16 ULP noise that fluctuates continuous metrics like +cosine, while still catching every real implementation regression. + +Configuration: +- **NPU FlashAttention is on** (`--npu-attn on` is the default) — verify + exercises the full NPU end-to-end production path: GEMV + RMSNorm + + RoPE + FlashAttention + LM-head GEMV. +- **Lite-mode runners**: skip per-layer intermediate capture, KV-cache + copies, and the CPU-side full-sequence LM-head recompute. Only the + per-step top-1 token + top-5 logits are read. +- **Tokenizer cached** via `functools.lru_cache` (no per-prompt reload). +- **MODEL=instruct** (default) uses `meta-llama/Llama-3.2-1B-Instruct` + with `prompts/instruct.txt` (instruction-style prompts). +- **MODEL=base** uses `meta-llama/Llama-3.2-1B` with `prompts/base.txt` + (continuation-style prompts matched to the base checkpoint's behavior). + +## `make diagnosis` — the inside-probing lens + +Reach for this when verify flags an issue and you need to localize. + +For one prompt, runs prefill on NPU + HF and reports per-position cosine ++ element-wise abs error for each layer's `ffn_out` (the block output). +Layers 0..n_layers-2 use each runner's raw layer output; the last layer +uses each runner's post-final-RMSNorm hidden state (HF exposes +`hidden_states[n_layers]` as post-norm by HF v5.3 convention; NPU +produces the equivalent via the final_norm step inside its production +LM-head GEMV path). + +**Diagnosis is informational only — it never fails the run.** The +verify gate is the correctness signal. The cosine table tells you where +the NPU implementation drifts most from HF (which layer, by how much), +which is what you want when triaging a real verify failure or weighing +a kernel-side optimization. Inspect the table by hand. + +Defaults to `--npu-attn on` so the inside-probing exercises the same +end-to-end NPU production path verify gates against. Diagnosis only +probes `ffn_out` (the block output), not `attn_out`, so the previous +runner-side per-layer attn_out reshape bug under `--npu-attn on` does +not affect this lens. + +## Output + +Each run writes a timestamped pair of files in `reports/`: + +- **verify**: `verify_topk_token_YYYYMMDD-HHMMSS.{json,md}` — Prompts table + + per-prompt top-k inclusion table with agreed-prefix sub-lines. +- **diagnosis**: `diagnosis_YYYYMMDD-HHMMSS.{json,md}` — single + per-layer cosine + max_abs table. + +`reports/` is gitignored. + +## Memory + +Real-weight runs need ~5 GB for the HF model + project numpy weights +shared by the NPU runner. Plan for ~6-8 GB working set. + +## File map + +| File | What | +|---|---| +| `verify_runner.py` | CLI orchestrator — picks `verify` vs `diagnosis` by `--prompts` | +| `comparators.py` | `compare_pair` (cosine + max_abs), `compute_topk_set_check` (top-k token-level), `topk_token_ids` | +| `report.py` | `Report` accumulator + JSON / markdown dumpers | +| `runners/npu_runner.py` | NPU production prefill + decode wrapper | +| `runners/hf_runner.py` | HuggingFace transformers bf16 wrapper | +| `runners/_records.py` | `PrefillRecord` / `DecodeStepRecord` dataclasses | +| `prompts/instruct.txt` | 8 instruction-style prompts (verify MODEL=instruct) | +| `prompts/base.txt` | 8 continuation-style prompts (verify MODEL=base) | diff --git a/programming_examples/llama32_1b/verify/__init__.py b/programming_examples/llama32_1b/verify/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/programming_examples/llama32_1b/verify/comparators.py b/programming_examples/llama32_1b/verify/comparators.py new file mode 100644 index 000000000..22349d6af --- /dev/null +++ b/programming_examples/llama32_1b/verify/comparators.py @@ -0,0 +1,246 @@ +"""Numerical comparators for end-to-end verify. + +All metrics are pure numpy. Inputs may be bfloat16 or float32; we cast to +float32 internally. +""" + +from __future__ import annotations + +from dataclasses import asdict, dataclass +from typing import Optional + +import numpy as np + + +def per_position_cosine(a: np.ndarray, b: np.ndarray) -> np.ndarray: + """Cosine similarity per position (per row). + + Reshape the inputs to (n_positions, feature_dim) by treating axis 0 as + the position axis and flattening all remaining axes. Returns a 1D array + of length n_positions, with NaN-safe handling: positions where either + side has zero norm return 0.0 (not NaN). + """ + a = np.asarray(a, dtype=np.float32) + b = np.asarray(b, dtype=np.float32) + if a.shape != b.shape: + raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}") + n_pos = a.shape[0] + a2 = a.reshape(n_pos, -1) + b2 = b.reshape(n_pos, -1) + dot = np.sum(a2 * b2, axis=1) + na = np.linalg.norm(a2, axis=1) + nb = np.linalg.norm(b2, axis=1) + denom = na * nb + out = np.zeros(n_pos, dtype=np.float32) + mask = denom > 0 + out[mask] = dot[mask] / denom[mask] + return out + + +def aggregate(cosines: np.ndarray) -> dict: + """Aggregate per-position cosines into {min, p5, median, mean}.""" + arr = np.asarray(cosines, dtype=np.float32) + return { + "min": float(arr.min()), + "p5": float(np.percentile(arr, 5)), + "median": float(np.median(arr)), + "mean": float(arr.mean()), + } + + +def error_metrics(a: np.ndarray, b: np.ndarray) -> dict: + """Element-wise abs/rel error stats — diagnostic complement to cosine. + + cosine is direction-only and ignores magnitude (e.g. b = 2*a -> cos = 1). + abs/rel error catches the magnitude-side errors cosine misses. + """ + a = np.asarray(a, dtype=np.float32).flatten() + b = np.asarray(b, dtype=np.float32).flatten() + diff = np.abs(a - b) + denom = np.maximum(np.abs(b), 1e-6) + rel = diff / denom + return { + "max_abs": float(diff.max()), + "mean_abs": float(diff.mean()), + "max_rel": float(rel.max()), + "mean_rel": float(rel.mean()), + } + + +@dataclass +class ComparisonRecord: + """One per-layer probe result. Pure observation — diagnosis does not gate + on these (`make verify` is the gate). Threshold + status fields used to + live here and were retired with the threshold-based diagnosis design.""" + + name: str + pair: str # "npu_vs_hf" + layer: Optional[int] + cosine: dict # {min, p5, median, mean} + errors: dict # {max_abs, mean_abs, max_rel, mean_rel} + + def to_dict(self) -> dict: + return asdict(self) + + +def compare_pair( + name: str, npu: np.ndarray, hf: np.ndarray, layer: int | None +) -> ComparisonRecord: + """Compute per-position cosine + element-wise error for one NPU vs HF + layer probe. No threshold, no pass/fail — diagnosis is informational.""" + cos = per_position_cosine(npu, hf) + return ComparisonRecord( + name=name, + pair="npu_vs_hf", + layer=layer, + cosine=aggregate(cos), + errors=error_metrics(npu, hf), + ) + + +# --------------------------------------------------------------------------- +# Token-level top-k set inclusion check (the model-level correctness gate) +# --------------------------------------------------------------------------- +# +# Mirrors the logic of vLLM's tests/models/utils.py::check_logprobs_close. +# At each generation step: +# - If both runners chose the same token, skip (no check needed). +# - Otherwise: the first divergence is the only step we check. Each side's +# chosen token must appear in the OTHER side's top-k. If either fails, +# status is FAIL with a human-readable reason. If both succeed, status +# is OK — divergence is informational drift within the top-k band. +# After the first divergence we stop (vLLM does the same: once divergent, the +# downstream tokens are no longer apples-to-apples since each side is feeding +# its own chosen token into the next step). +# +# This is the discrete-judgment escape from continuous-metric ULP wars: bf16 +# noise can flip top-1 even between two implementations that are mathematically +# equivalent, but it almost never displaces a token out of the top-5. + + +def topk_token_ids(z: np.ndarray, k: int = 5) -> list[int]: + """Return the top-k token IDs from a 1D logit vector, highest first. + + Tie-breaking matches numpy.argmax: when two logits are exactly equal + (which happens routinely with bf16 inputs cast to F32, since adjacent + bf16 values land at the same F32 representation), the smaller token + ID wins. Without this, topk_token_ids[0] could disagree with + np.argmax(z) on the SAME array. + """ + z = np.asarray(z) + if z.ndim != 1: + raise ValueError(f"expected 1D logit vector, got shape {z.shape}") + if k > z.shape[0]: + raise ValueError(f"k={k} > vocab_size={z.shape[0]}") + idx = np.argpartition(-z, k - 1)[:k] + # lexsort: last key is primary. Primary = -z[idx] (largest z first); + # secondary = idx (smaller token-ID first as tiebreaker). + order = np.lexsort((idx, -z[idx])) + idx = idx[order] + return idx.tolist() + + +@dataclass +class TopKCheckRecord: + """Result of a single top-k token-level inclusion check on one prompt.""" + + prompt_idx: int + prompt_text: str # may be truncated for the report + n_steps: int + k: int + divergence_step: Optional[int] + test_chosen_at_div: Optional[int] + ref_chosen_at_div: Optional[int] + test_topk_at_div: Optional[list[int]] + ref_topk_at_div: Optional[list[int]] + status: str # "OK" | "FAIL" + fail_reason: Optional[str] + # 1-based rank of each side's chosen token within the OTHER side's top-k. + # None when the chosen token is not present (FAIL on that direction) or + # when there is no divergence at all. + test_chosen_rank_in_ref: Optional[int] = None + ref_chosen_rank_in_test: Optional[int] = None + # Decoded human-readable rendering (orchestrator populates via tokenizer). + test_chosen_text_at_div: Optional[str] = None + ref_chosen_text_at_div: Optional[str] = None + agreed_prefix_text: Optional[str] = None + + def to_dict(self) -> dict: + return asdict(self) + + +def compute_topk_set_check( + test_chosen: list[int], + test_topk: list[list[int]], + ref_chosen: list[int], + ref_topk: list[list[int]], + k: int = 5, + prompt_idx: int = 0, + prompt_text: str = "", +) -> TopKCheckRecord: + """Top-k token-level inclusion check on one prompt's generation sequence. + + Walk in lockstep. On the first chosen-token mismatch, both sides' chosen + tokens must appear in the OTHER side's top-k; otherwise FAIL. Stop after + the first divergence (mirrors vLLM's check_logprobs_close). All-match + returns OK with divergence_step=None. + """ + n = min(len(test_chosen), len(ref_chosen), len(test_topk), len(ref_topk)) + for i in range(n): + if test_chosen[i] == ref_chosen[i]: + continue + ref_top = list(ref_topk[i][:k]) + test_top = list(test_topk[i][:k]) + try: + test_rank: Optional[int] = ref_top.index(test_chosen[i]) + 1 + except ValueError: + test_rank = None + try: + ref_rank: Optional[int] = test_top.index(ref_chosen[i]) + 1 + except ValueError: + ref_rank = None + test_in_ref = test_rank is not None + ref_in_test = ref_rank is not None + if test_in_ref and ref_in_test: + status, reason = "OK", None + else: + parts = [] + if not test_in_ref: + parts.append( + f"test chose {test_chosen[i]} but it is not in ref top-{k} " + f"({ref_top})" + ) + if not ref_in_test: + parts.append( + f"ref chose {ref_chosen[i]} but it is not in test top-{k} " + f"({test_top})" + ) + status, reason = "FAIL", "; ".join(parts) + return TopKCheckRecord( + prompt_idx=prompt_idx, + prompt_text=prompt_text, + n_steps=n, + k=k, + divergence_step=i, + test_chosen_at_div=int(test_chosen[i]), + ref_chosen_at_div=int(ref_chosen[i]), + test_topk_at_div=[int(t) for t in test_top], + ref_topk_at_div=[int(t) for t in ref_top], + status=status, + fail_reason=reason, + test_chosen_rank_in_ref=test_rank, + ref_chosen_rank_in_test=ref_rank, + ) + return TopKCheckRecord( + prompt_idx=prompt_idx, + prompt_text=prompt_text, + n_steps=n, + k=k, + divergence_step=None, + test_chosen_at_div=None, + ref_chosen_at_div=None, + test_topk_at_div=None, + ref_topk_at_div=None, + status="OK", + fail_reason=None, + ) diff --git a/programming_examples/llama32_1b/verify/prompts/base.txt b/programming_examples/llama32_1b/verify/prompts/base.txt new file mode 100644 index 000000000..29e9fc91b --- /dev/null +++ b/programming_examples/llama32_1b/verify/prompts/base.txt @@ -0,0 +1,15 @@ +# Prompts used by `make verify MODEL=base` (Llama-3.2-1B base, no instruction +# tuning). Each prompt is intentionally an incomplete sentence — the base +# model continues raw text rather than answering instructions, so the +# topic is set up by leaving the model with a clear "next phrase". +# Topics deliberately mirror instruct.txt so base vs Instruct behavior +# can be compared on adjacent rows. +# One prompt per line. Lines starting with '#' or empty are ignored. +GPU stands for +The capital of France is +Artificial intelligence is a branch of computer science that +A neural network consists of +Once upon a time, there was a robot who dreamed about +The COVID-19 pandemic, which began in late 2019, +The Mona Lisa was painted by +The French translation of "The early bird catches the worm" is diff --git a/programming_examples/llama32_1b/verify/prompts/instruct.txt b/programming_examples/llama32_1b/verify/prompts/instruct.txt new file mode 100644 index 000000000..3e5ad25dc --- /dev/null +++ b/programming_examples/llama32_1b/verify/prompts/instruct.txt @@ -0,0 +1,13 @@ +# Prompts used by `make verify MODEL=instruct` (Llama-3.2-1B-Instruct). +# 7 prompts originally from vllm/tests/prompts/example.txt; prompt 0 +# swapped to "Introduce me what is GPU" (more relevant than the vLLM +# self-promo line for this project). +# One prompt per line. Lines starting with '#' or empty are ignored. +Introduce me what is GPU +Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020. +Compare and contrast artificial intelligence with human intelligence in terms of processing information. +Describe the basic components of a neural network and how it can be trained. +Write a short story about a robot that dreams for the first time. +Analyze the impact of the COVID-19 pandemic on global economic structures and future business models. +Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies. +Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.' diff --git a/programming_examples/llama32_1b/verify/report.py b/programming_examples/llama32_1b/verify/report.py new file mode 100644 index 000000000..8a1bcf208 --- /dev/null +++ b/programming_examples/llama32_1b/verify/report.py @@ -0,0 +1,182 @@ +"""Report accumulator + JSON / markdown dumpers. + +Two layouts produced from the same Report instance: + + `make verify` Top-k token-level inclusion gate. Records are added + via add_topk(pair, record); the markdown dumps a + Prompts table + per-pair top-k tables with agreed- + prefix sub-lines. has_failure() reflects the gate. + + `make diagnosis` Per-layer ffn_out cosine + max_abs (NPU vs HF bf16). + Records are added via add(record); the markdown + dumps one informational table with one row per + probed layer. Diagnosis never fails the run — + the verify gate is the only correctness signal. +""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Optional + +from comparators import ComparisonRecord, TopKCheckRecord + + +class Report: + def __init__(self, config: dict): + self.config: dict = dict(config) + self.records: list[ComparisonRecord] = [] + self.topk_checks: list[tuple[str, TopKCheckRecord]] = [] + self.prompts: list[str] = [] + + def add(self, record: ComparisonRecord) -> None: + self.records.append(record) + + def add_topk(self, pair: str, record: TopKCheckRecord) -> None: + self.topk_checks.append((pair, record)) + + def set_prompts(self, prompts: list[str]) -> None: + self.prompts = list(prompts) + + def summary(self) -> dict: + topk_passed = sum(1 for _, r in self.topk_checks if r.status == "OK") + topk_failed = sum(1 for _, r in self.topk_checks if r.status == "FAIL") + return { + "n_layer_records": len(self.records), + "topk_passed": topk_passed, + "topk_failed": topk_failed, + } + + def has_failure(self) -> bool: + # Only the verify-mode top-k gate signals failure. Diagnosis is + # informational; per-layer cosine numbers are inspected by humans, + # not gated. + for pair, rec in self.topk_checks: + if pair == "npu_vs_hf" and rec.status == "FAIL": + return True + return False + + def dump_json(self, path: str | Path) -> None: + topk_view: Optional[list[dict]] = None + if self.topk_checks: + topk_view = [ + {"pair": pair, **rec.to_dict()} for pair, rec in self.topk_checks + ] + data = { + "config": self.config, + "prompts": self.prompts or None, + "per_layer": [r.to_dict() for r in self.records], + "topk_checks": topk_view, + "summary": self.summary(), + } + Path(path).write_text(json.dumps(data, indent=2)) + + def dump_markdown(self, path: str | Path) -> None: + s = self.summary() + verdict = "FAIL" if self.has_failure() else "PASS" + lines: list[str] = [] + lines.append("# Verify report") + cfg_str = ", ".join(f"{k}={v}" for k, v in self.config.items()) + lines.append(f"\nConfig: {cfg_str}") + lines.append(f"\nResult: **{verdict}**") + if self.topk_checks: + lines.append( + f"\nTop-k token gate: {s['topk_passed']} PASS / " + f"{s['topk_failed']} FAIL " + f"(across {len(self.topk_checks)} prompt-pair checks)" + ) + if self.prompts: + lines.append("\n## Prompts\n") + lines.append("| # | Prompt |\n|--:|--------|") + for pi, p in enumerate(self.prompts): + cell = p.replace("|", "\\|").replace("\n", " ").replace("\r", " ") + lines.append(f"| {pi} | {cell} |") + + # ---- Diagnosis: per-layer ffn_out (NPU vs HF) ----------------------- + ffn_records = [r for r in self.records if r.name == "ffn_out"] + if ffn_records: + lines.append( + "\n## Per-layer hidden state (ffn_out, NPU vs HF bf16)\n" + "_Informational — diagnosis does not fail the run; " + "`make verify` is the gate._\n" + ) + lines.append("| Layer | cos_p5 | cos_min | cos_median | max_abs |") + lines.append("|------:|-------:|--------:|-----------:|--------:|") + for r in ffn_records: + lines.append( + f"| {r.layer} | {r.cosine['p5']:.6f} " + f"| {r.cosine['min']:.6f} | {r.cosine['median']:.6f} " + f"| {r.errors['max_abs']:.4g} |" + ) + + # ---- Verify: top-k inclusion (per-pair tables) ---------------------- + if self.topk_checks: + by_pair: dict[str, list] = {} + for pair, rec in self.topk_checks: + by_pair.setdefault(pair, []).append(rec) + + def _format_choice(text, token_id, rank): + """Render one side's chosen token as `"text" (#rank)` or `(✗)`.""" + label = text if text is not None else f"id={token_id}" + if rank is not None: + return f"{label} (#{rank})" + return f"{label} (✗)" + + for pair, recs in by_pair.items(): + pair_passed = sum(1 for r in recs if r.status == "OK") + pair_failed = sum(1 for r in recs if r.status == "FAIL") + k = recs[0].k if recs else "?" + test_side, ref_side = (s.upper() for s in pair.split("_vs_")) + lines.append( + f"\n## Top-k token inclusion — {pair} " + f"(k={k}, {pair_passed}/{len(recs)} PASS)\n" + ) + lines.append( + f"| # | Prompt | Steps | Diverge step " + f"| {test_side} choice (rank in {ref_side}) " + f"| {ref_side} choice (rank in {test_side}) | Status |" + ) + lines.append( + "|--:|--------|------:|-------------:" + "|---------|---------|:-------|" + ) + for r in recs: + if r.divergence_step is None: + div_cell = "—" + test_cell = "(all match)" + ref_cell = "(all match)" + else: + div_cell = str(r.divergence_step) + test_cell = _format_choice( + r.test_chosen_text_at_div, + r.test_chosen_at_div, + r.test_chosen_rank_in_ref, + ) + ref_cell = _format_choice( + r.ref_chosen_text_at_div, + r.ref_chosen_at_div, + r.ref_chosen_rank_in_test, + ) + prompt_cell = r.prompt_text.replace("|", "\\|") + lines.append( + f"| {r.prompt_idx} | {prompt_cell} | {r.n_steps} " + f"| {div_cell} | {test_cell} | {ref_cell} | {r.status} |" + ) + for r in recs: + if r.agreed_prefix_text and r.agreed_prefix_text != '""': + lines.append( + f"\n*Prompt {r.prompt_idx} agreed prefix " + f"(steps 0-{r.divergence_step - 1}):* " + f"{r.agreed_prefix_text}" + ) + for r in recs: + if r.fail_reason: + lines.append(f"\n*Prompt {r.prompt_idx} FAIL:* {r.fail_reason}") + if pair_failed: + lines.append( + f"\n_{pair_failed}/{len(recs)} prompts failed top-{k} " + "inclusion at first divergence._" + ) + + Path(path).write_text("\n".join(lines) + "\n") diff --git a/programming_examples/llama32_1b/verify/runners/__init__.py b/programming_examples/llama32_1b/verify/runners/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/programming_examples/llama32_1b/verify/runners/_records.py b/programming_examples/llama32_1b/verify/runners/_records.py new file mode 100644 index 000000000..7142a04fa --- /dev/null +++ b/programming_examples/llama32_1b/verify/runners/_records.py @@ -0,0 +1,33 @@ +"""Shared Record dataclasses returned by all Runner implementations.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np + + +@dataclass +class PrefillRecord: + layer_intermediates: list[dict[str, np.ndarray]] # len == n_layers + final_hidden: np.ndarray + # final_hidden after the model's final RMSNorm — the value that feeds + # into the LM-head matmul. HF transformers exposes this as + # output_hidden_states[n_layers] (which is post-final-norm by HF v5.3 + # convention; see hf_runner for the empirical confirmation). NPU + # produces it natively in non-lite mode (the same array used to + # compute final_logits). Diagnosis pairs this NPU vs HF cell as the + # "layer 15" probe so the last layer is not silently skipped. + final_hidden_normed: np.ndarray + logits_at_pred: np.ndarray + top1_token: int + + +@dataclass +class DecodeStepRecord: + step: int + current_pos: int + input_token: int + layer_intermediates: list[dict[str, np.ndarray]] + lm_head_logits: np.ndarray + top1_token: int diff --git a/programming_examples/llama32_1b/verify/runners/hf_runner.py b/programming_examples/llama32_1b/verify/runners/hf_runner.py new file mode 100644 index 000000000..c1ce429dd --- /dev/null +++ b/programming_examples/llama32_1b/verify/runners/hf_runner.py @@ -0,0 +1,134 @@ +"""HuggingFace transformers runner — bf16, runs on CPU. + +The single bf16 reference for both `make verify` and `make diagnosis`. +Two modes: + - lite_mode=True (used by `make verify`): pass output_hidden_states= + False so HF skips the per-layer hidden-state list internally; only + logits + top1 are read back. + - lite_mode=False (used by `make diagnosis`): collect per-layer + hidden_states. Per HF transformers v5.3 convention, hidden_states is + a tuple of length n_layers + 1: index 0 is the embedding output; + indices 1..n_layers-1 are the *raw* outputs of layers 0..n_layers-2; + index n_layers is the *post-final-norm* version of layer n_layers-1 + (the last layer's raw output is NOT exposed). We therefore expose + ffn_out for layers 0..n_layers-2 and ALSO surface hidden_states[-1] + as final_hidden_normed so the orchestrator can pair the L15 cell + with the NPU's own post-final-norm hidden state. + +All intermediates are cast to float32 NumPy before returning since NumPy +has no native bfloat16 and the comparators all operate in F32 space. +""" + +from __future__ import annotations + +import numpy as np +import torch +from transformers import AutoModelForCausalLM + +from runners._records import PrefillRecord, DecodeStepRecord + + +class HfRunner: + name = "hf_bf16" + + def __init__( + self, + model_name: str, + config, + max_seq: int, + lite_mode: bool = False, + ): + self.config = config + self.max_seq = max_seq + self.lite_mode = lite_mode + self.model = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=torch.bfloat16 + ) + self.model.eval() + self.past_key_values = None + self._n_layers = config.n_layers + self._emb_dim = config.emb_dim + self._n_kv = config.n_kv_heads + self._head_dim = config.head_dim + + @torch.no_grad() + def prefill(self, prompt_tokens: np.ndarray) -> PrefillRecord: + input_ids = torch.tensor(prompt_tokens, dtype=torch.long).unsqueeze(0) + out = self.model( + input_ids, + output_hidden_states=not self.lite_mode, + use_cache=True, + return_dict=True, + ) + logits = out.logits[0, -1].cpu().float().numpy() # (vocab,) + top1 = int(np.argmax(logits)) + self.past_key_values = out.past_key_values + if self.lite_mode: + empty = np.empty((0,), dtype=np.float32) + return PrefillRecord( + layer_intermediates=[], + final_hidden=empty, + final_hidden_normed=empty, + logits_at_pred=logits, + top1_token=top1, + ) + hidden_states = out.hidden_states + layer_intermediates: list[dict[str, np.ndarray]] = [] + for li in range(self._n_layers - 1): + # .float() upcasts bf16 to f32 — NumPy has no native bf16. + ffn_out = hidden_states[li + 1][0].cpu().float().numpy() + layer_intermediates.append({"ffn_out": ffn_out}) + # Last-layer entry intentionally has no ffn_out — the orchestrator + # uses final_hidden_normed for the L15 probe instead. + layer_intermediates.append({}) + # hidden_states[-1] is the post-final-norm version of the last + # layer's output (HF v5.3 convention). Same value the model fed + # into lm_head. Empirically: for raw last-layer hidden of magnitude + # ~130, max|raw + final_norm - hs[-1]| ~ 1e-2. + final_hidden_normed = hidden_states[-1][0].cpu().float().numpy() + return PrefillRecord( + layer_intermediates=layer_intermediates, + final_hidden=final_hidden_normed, # legacy field; same value here + final_hidden_normed=final_hidden_normed, + logits_at_pred=logits, + top1_token=top1, + ) + + @torch.no_grad() + def decode_step(self, input_token: int, current_pos: int) -> DecodeStepRecord: + if self.past_key_values is None: + raise RuntimeError("decode_step called before prefill") + input_ids = torch.tensor([[input_token]], dtype=torch.long) + out = self.model( + input_ids, + past_key_values=self.past_key_values, + output_hidden_states=False, # decode probes are not collected + use_cache=True, + return_dict=True, + ) + logits = out.logits[0, -1].cpu().float().numpy() + top1 = int(np.argmax(logits)) + self.past_key_values = out.past_key_values + return DecodeStepRecord( + step=current_pos, + current_pos=current_pos, + input_token=input_token, + layer_intermediates=[], + lm_head_logits=logits, + top1_token=top1, + ) + + @torch.no_grad() + def free_run_decode(self, prompt_tokens: np.ndarray, n_tokens: int) -> list[int]: + # Reset cache for an isolated free run. + self.past_key_values = None + prefill_rec = self.prefill(prompt_tokens) + out_tokens = [prefill_rec.top1_token] + cur = len(prompt_tokens) + next_token = prefill_rec.top1_token + for _ in range(n_tokens): + rec = self.decode_step(input_token=next_token, current_pos=cur) + out_tokens.append(rec.top1_token) + cur += 1 + next_token = rec.top1_token + return out_tokens diff --git a/programming_examples/llama32_1b/verify/runners/npu_runner.py b/programming_examples/llama32_1b/verify/runners/npu_runner.py new file mode 100644 index 000000000..db0feecc8 --- /dev/null +++ b/programming_examples/llama32_1b/verify/runners/npu_runner.py @@ -0,0 +1,197 @@ +"""NPU runner — thin adapter over the production prefill / decode functions. + +Delegates the actual work to: + - llama32_1b_inference.prepare_runtime (runtime setup) + - llama32_1b_inference.run_npu_prefill (prefill + KV cache extract + LM head) + - llama32_1b_inference.run_npu_decode_step (one decode step + LM head) + - llama32_1b_prefill.compile_all_kernels / decode.compile_decode_kernels + +The runner holds the stateful pieces (kernel caches + KV cache) across calls; +the actual NPU compute path is identical to what `make run` exercises. Any +change to the production functions is automatically picked up by `make verify`. + +Two modes: + - lite_mode=True (used by `make verify`): prefill returns logits + chosen + token only; layer_intermediates is left empty. + - lite_mode=False (used by `make diagnosis`): also collects per-layer + ffn_out + the post-final-norm hidden state for the L15 probe. The + layer-intermediate collection runs OUTSIDE the production path — it + re-invokes run_transformer_block layer-by-layer with the same inputs, + capturing the dict each block returns. This is a diagnosis-only side + channel; verify never touches it. +""" + +from __future__ import annotations + +import numpy as np +from ml_dtypes import bfloat16 + +from kernel_builder.cache import KernelCache +from llama32_1b_prefill import ( + compile_all_kernels as compile_prefill_kernels, + run_transformer_block as run_prefill_block, +) +from llama32_1b_decode import compile_decode_kernels +from llama32_1b_inference import ( + prepare_runtime, + run_npu_prefill, + run_npu_decode_step, +) +from llama32_1b_weights import generate_rope_lut +from llama32_1b_cpu_helpers import rms_norm + +from runners._records import PrefillRecord, DecodeStepRecord + + +class NpuRunner: + name = "npu" + + def __init__( + self, + weights, + config, + max_seq: int, + tokenizer, + npu_attn: bool = True, + lite_mode: bool = False, + ): + self.weights = weights + self.config = config + self.max_seq = max_seq + self.npu_attn = npu_attn + self.cpu_attn = not npu_attn + self.lite_mode = lite_mode + # tokenizer is needed only to give run_npu_prefill an EOS-token-id + # for padding the (raw) prompt to max_seq. Verify orchestrator passes + # the same tokenizer it uses to encode prompts, so pad-token ID + # matches the prompt's tokenization. + self._tokenizer = tokenizer + + self.rope_lut_bf16 = generate_rope_lut(config=config, seq_len=max_seq).astype( + bfloat16 + ) + + # Compile prefill + decode kernels (same ones production compiles). + self.prefill_cache = KernelCache(verbose=False) + compile_prefill_kernels( + self.prefill_cache, + config, + seq_len=max_seq, + cpu_attn=self.cpu_attn, + ) + self.decode_cache = KernelCache(verbose=False) + compile_decode_kernels(self.decode_cache, config) + + # Production prepare_runtime: weight pre-transpose, per-layer index + # tagging, BO preloading. + prepare_runtime( + self.prefill_cache, + self.decode_cache, + weights, + config, + max_seq, + self.rope_lut_bf16, + ) + + # KV cache state lives across decode_step calls within one prefill. + # prefill() repopulates this from run_npu_prefill's return. + self.k_cache = None + self.v_cache = None + + def prefill(self, prompt_tokens: np.ndarray) -> PrefillRecord: + # Production-side run_once pre-pads the prompt to the kernel's + # compiled seq_len (= self.max_seq) with eos_token_id before calling + # run_npu_prefill. Mirror that here so the verify path hits exactly + # the same code with exactly the same shape. + eos = self._tokenizer.eos_token_id + if len(prompt_tokens) < self.max_seq: + padded = list(prompt_tokens) + [eos] * (self.max_seq - len(prompt_tokens)) + else: + padded = list(prompt_tokens)[: self.max_seq] + # Production path — exact same code make run uses. + prefill_token, logits_row, k_cache, v_cache, prompt_len = run_npu_prefill( + padded, + self.weights, + self.config, + self.prefill_cache, + self.decode_cache, + self.rope_lut_bf16, + self.max_seq, + tokenizer=self._tokenizer, + cpu_attn=self.cpu_attn, + profile=False, + quiet=True, + ) + # Persist KV cache for subsequent decode_step calls in this run. + self.k_cache = k_cache + self.v_cache = v_cache + + if self.lite_mode: + empty = np.empty((0,), dtype=np.float32) + return PrefillRecord( + layer_intermediates=[], + final_hidden=empty, + final_hidden_normed=empty, + logits_at_pred=logits_row, + top1_token=prefill_token, + ) + + # ---- Diagnosis-only side channel: re-run the prefill layer loop + # to capture per-layer ffn_out + the post-final-norm hidden state. + # This is duplicate compute (~3-5 s extra) but only happens in + # diagnosis mode, which is single-prompt by design. + cfg = self.config + if len(prompt_tokens) < self.max_seq: + pad = np.zeros(self.max_seq - len(prompt_tokens), dtype=prompt_tokens.dtype) + padded = np.concatenate([prompt_tokens, pad]) + else: + padded = prompt_tokens[: self.max_seq] + embed = self.weights.embed_table[padded].astype(np.float32) + x = embed.astype(bfloat16) + layer_intermediates: list[dict[str, np.ndarray]] = [] + for li in range(cfg.n_layers): + x, ints = run_prefill_block( + x, + self.weights.layers[li], + self.rope_lut_bf16, + cfg, + self.prefill_cache, + layer_idx=li, + cpu_attn=self.cpu_attn, + verbose=False, + ) + fo_full = np.asarray(ints["ffn_out"]) + layer_intermediates.append({"ffn_out": fo_full[:prompt_len]}) + + # Post-final-norm hidden — the value the LM-head GEMV sees. + x_full_f32 = np.asarray(x, dtype=np.float32)[:prompt_len] + x_full_normed = rms_norm(x_full_f32, self.weights.final_norm) + + return PrefillRecord( + layer_intermediates=layer_intermediates, + final_hidden=x_full_f32, + final_hidden_normed=x_full_normed.astype(np.float32), + logits_at_pred=logits_row, + top1_token=prefill_token, + ) + + def decode_step(self, input_token: int, current_pos: int) -> DecodeStepRecord: + x = self.weights.embed_table[input_token].astype(bfloat16) + next_token, logits = run_npu_decode_step( + x, + self.weights, + self.config, + self.decode_cache, + self.rope_lut_bf16, + self.k_cache, + self.v_cache, + current_pos, + ) + return DecodeStepRecord( + step=current_pos, + current_pos=current_pos, + input_token=input_token, + layer_intermediates=[], + lm_head_logits=logits, + top1_token=next_token, + ) diff --git a/programming_examples/llama32_1b/verify/verify_runner.py b/programming_examples/llama32_1b/verify/verify_runner.py new file mode 100644 index 000000000..c5a42a347 --- /dev/null +++ b/programming_examples/llama32_1b/verify/verify_runner.py @@ -0,0 +1,373 @@ +"""verify_runner.py — orchestrate the verify gate and the diagnosis lens. + +Two modes selected by --prompts: + + --prompts topk_token `make verify` token-level top-k inclusion gate. + NPU + HF bf16 only, lite mode + runners, 8 prompts × 32 greedy + tokens, top-5 set inclusion. + Method mirrors vLLM's + check_logprobs_close. ~4 min/run. + + --prompts single `make diagnosis` inside-probing microscope. NPU + HF + bf16 only, full-capture runners, + one prompt's prefill, per-layer + ffn_out cosine + max_abs (NPU vs + HF) for layers 0..n_layers-2 plus + the post-final-norm hidden as the + L15 cell. No decode loop, no + logits gate, no token match — + `verify` already checks the + user-visible output. +""" + +from __future__ import annotations + +import argparse +import functools +import sys +from datetime import datetime +from pathlib import Path +from typing import Optional + +import numpy as np + +# Ensure project + verify dirs are importable. +HERE = Path(__file__).parent +PROJECT = HERE.parent +sys.path.insert(0, str(PROJECT)) +sys.path.insert(0, str(HERE)) + +from comparators import ( + compare_pair, + compute_topk_set_check, + topk_token_ids, +) +from report import Report +from runners.npu_runner import NpuRunner + +DEFAULT_PROMPT = "The capital of France is" + +# Same architecture (16 layers, emb=2048, n_heads=32, n_kv_heads=8, +# head_dim=64, vocab=128256) — only the weight tensors and tokenizer +# differ. base = original pretraining checkpoint (text continuation); +# instruct = what vLLM and other production stacks deploy. +MODEL_CHOICES = { + "base": "meta-llama/Llama-3.2-1B", + "instruct": "meta-llama/Llama-3.2-1B-Instruct", +} +BLOCK_PROBE = "ffn_out" + +# Token-level top-k inclusion gate constants. Values mirror vLLM's +# check_logprobs_close defaults (max_tokens=32, num_logprobs=5). +PROMPTS_DIR = HERE / "prompts" +DEFAULT_PROMPTS_FILE = { + "base": PROMPTS_DIR / "base.txt", + "instruct": PROMPTS_DIR / "instruct.txt", +} +GATE_N_TOKENS = 32 # greedy tokens decoded per prompt +GATE_K = 5 # top-k inclusion threshold + + +def _load_weights(weights_mode: str, config, seed: int, model_name: str): + from llama32_1b_weights import synthetic_weights, load_weights + + if weights_mode == "synthetic": + return synthetic_weights(config, seed=seed) + return load_weights(model_name, config=config) + + +@functools.lru_cache(maxsize=4) +def _get_tokenizer(model_name: str): + """Cached tokenizer loader. AutoTokenizer.from_pretrained is ~50 ms even + when the files are local — pre-cache, we paid that 8 times per verify run.""" + from transformers import AutoTokenizer + + return AutoTokenizer.from_pretrained(model_name) + + +def _tokenize(prompt: str, model_name: str): + tok = _get_tokenizer(model_name) + ids = tok.encode(prompt) + return np.array(ids, dtype=np.int64), tok + + +def _load_prompts(path: Path) -> list[str]: + """Load prompts from a file; skip blank and '#' comment lines.""" + out: list[str] = [] + for line in path.read_text().splitlines(): + line = line.strip() + if line and not line.startswith("#"): + out.append(line) + return out + + +def _decode_token_for_display(tokenizer, token_id: Optional[int]) -> Optional[str]: + """Render one token ID as a quoted, escape-safe string for the report. + Quoting keeps leading whitespace visible (most LLM tokens carry one).""" + if token_id is None: + return None + text = tokenizer.decode([int(token_id)]) + text = text.replace("\\", "\\\\").replace("|", "\\|") + text = text.replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t") + return f'"{text}"' + + +def _generate_with_topk(runner, prompt_tokens: np.ndarray, n_tokens: int, k: int): + """Free-run greedy decode capturing chosen token + top-k token IDs per step. + + Returns (chosen_tokens, topk_per_step) — both length n_tokens. The first + entry is the prefill prediction; subsequent entries are decode-step + predictions, each fed as input to the next step. + + Sanity check: each step's chosen token MUST equal the first entry of + that step's top-k. If it does not, one of the runner's logit fields has + been mutated between top1_token computation and the field being read + here — print a loud warning so the rendered report is not misinterpreted + as a real model disagreement. + """ + + def _check(step_idx, chosen_id, topk_ids, tag): + if topk_ids and chosen_id != topk_ids[0]: + print( + f"[verify] WARN: {tag} step {step_idx} top1_token={chosen_id} " + f"!= topk[0]={topk_ids[0]} (full top-{k}={topk_ids}). " + "Indicates runner-side logit mutation between top1_token " + "and lm_head_logits/logits_at_pred capture.", + file=sys.stderr, + ) + + runner_tag = getattr(runner, "name", type(runner).__name__) + pf = runner.prefill(prompt_tokens) + chosen = [pf.top1_token] + topk = [topk_token_ids(np.asarray(pf.logits_at_pred), k)] + _check(0, pf.top1_token, topk[0], runner_tag) + cur = len(prompt_tokens) + next_tok = pf.top1_token + for step_i in range(1, n_tokens): + ds = runner.decode_step(next_tok, cur) + chosen.append(ds.top1_token) + step_topk = topk_token_ids(np.asarray(ds.lm_head_logits), k) + topk.append(step_topk) + _check(step_i, ds.top1_token, step_topk, runner_tag) + cur += 1 + next_tok = ds.top1_token + return chosen, topk + + +def _run_diagnosis(npu, hf, prompt_tokens, report, n_layers): + """Diagnosis lens: per-layer ffn_out (NPU vs HF bf16) for one prompt. + + For layers 0..n_layers-2 we compare each runner's raw layer output + (npu.layer_intermediates[li]['ffn_out'] vs hf.layer_intermediates[li] + ['ffn_out']). For the last layer we compare each runner's + final_hidden_normed (the post-final-RMSNorm hidden state that feeds + LM-head) — HF's hidden_states[n_layers] is post-norm by HF v5.3 + convention, and NPU exposes the equivalent via the same final_norm + application it does inside the production LM-head GEMV path. + + Diagnosis is informational only — no thresholds, no pass/fail. Inspect + the cosine table by hand; the verify gate is the actual correctness + signal. + """ + print("[diagnosis] prefill: NPU + HF...") + npu_pf = npu.prefill(prompt_tokens) + hf_pf = hf.prefill(prompt_tokens) + print("[diagnosis] comparing per-layer ffn_out (NPU vs HF bf16)...") + for li in range(n_layers - 1): + report.add( + compare_pair( + name=BLOCK_PROBE, + npu=npu_pf.layer_intermediates[li][BLOCK_PROBE], + hf=hf_pf.layer_intermediates[li][BLOCK_PROBE], + layer=li, + ) + ) + report.add( + compare_pair( + name=BLOCK_PROBE, + npu=npu_pf.final_hidden_normed, + hf=hf_pf.final_hidden_normed, + layer=n_layers - 1, + ) + ) + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--npu-attn", choices=["on", "off"], default="on") + p.add_argument("--prompt", default=DEFAULT_PROMPT) + p.add_argument("--weights", choices=["hf", "synthetic"], default="hf") + p.add_argument( + "--model", + choices=list(MODEL_CHOICES), + default="instruct", + help="Llama-3.2-1B checkpoint. Default 'instruct' matches what " + "production stacks deploy. 'base' is the original pretraining " + "checkpoint (text continuation).", + ) + p.add_argument("--report-dir", default=str(HERE / "reports")) + p.add_argument( + "--no-strict", + action="store_true", + help="Disable hard exit on FAIL (default: exit 1 on FAIL)", + ) + p.add_argument("--seed", type=int, default=42) + p.add_argument( + "--prompts", + choices=["single", "topk_token"], + default="single", + help="'single' (used by `make diagnosis`) probes per-layer ffn_out " + "for one prompt. 'topk_token' (used by `make verify`) runs the " + "8-prompt top-k token-level inclusion gate. The two modes are " + "exclusive.", + ) + p.add_argument( + "--prompts-file", + default=None, + help="Override the prompt file used by --prompts topk_token. " + "Defaults to verify/prompts/{model}.txt.", + ) + args = p.parse_args() + + from llama32_1b_weights import LlamaConfig + + config = LlamaConfig() + model_name = MODEL_CHOICES[args.model] + weights = _load_weights(args.weights, config, args.seed, model_name) + # Production prefill kernels are tiled for seq_len=2048; NpuRunner pads + # short prompts internally. + max_seq = 2048 + + in_verify_mode = args.prompts == "topk_token" + report = Report( + config={ + "mode": "verify" if in_verify_mode else "diagnosis", + "weights": args.weights, + "model": args.model, + "model_name": model_name, + "npu_attn": args.npu_attn == "on", + "prompt": args.prompt if not in_verify_mode else None, + } + ) + + # ---- Build runners ---- + # Both modes use NPU + HF bf16 only. Verify runs lite (no per-layer + # capture); diagnosis runs full-capture for the per-layer probe. + lite = in_verify_mode + print(f"[verify] mode = {report.config['mode']}, lite={lite}") + print("[verify] building NPU runner...") + npu = NpuRunner( + weights, + config, + max_seq=max_seq, + tokenizer=_get_tokenizer(model_name), + npu_attn=(args.npu_attn == "on"), + lite_mode=lite, + ) + from runners.hf_runner import HfRunner + + print(f"[verify] building HF runner ({model_name}, lite={lite}, may download)...") + try: + hf = HfRunner( + model_name=model_name, + config=config, + max_seq=max_seq, + lite_mode=lite, + ) + except Exception as e: + print(f"[verify] HF runner unavailable: {e}", file=sys.stderr) + sys.exit(1) + + # ---- Diagnosis path: single prompt, per-layer ffn_out only ---- + if not in_verify_mode: + prompt_tokens, _ = _tokenize(args.prompt, model_name) + _run_diagnosis(npu, hf, prompt_tokens, report, config.n_layers) + Path(args.report_dir).mkdir(parents=True, exist_ok=True) + stamp = datetime.now().strftime("%Y%m%d-%H%M%S") + json_path = Path(args.report_dir) / f"diagnosis_{stamp}.json" + md_path = Path(args.report_dir) / f"diagnosis_{stamp}.md" + report.dump_json(json_path) + report.dump_markdown(md_path) + print(f"\n[verify] Report: {md_path}") + print(f"[verify] JSON: {json_path}") + print(f"[verify] Summary: {report.summary()}") + if report.has_failure() and not args.no_strict: + print("[verify] FAIL — see report for details.", file=sys.stderr) + sys.exit(1) + print("[verify] PASS") + return + + # ---- Verify path: 8-prompt top-k token-level inclusion gate ---- + prompts_path = ( + Path(args.prompts_file) + if args.prompts_file + else DEFAULT_PROMPTS_FILE[args.model] + ) + prompts = _load_prompts(prompts_path) + report.set_prompts(prompts) + report.config["prompts_file"] = str(prompts_path) + print( + f"[verify] top-k token gate: {len(prompts)} prompts × " + f"{GATE_N_TOKENS} tokens, k={GATE_K} (from {prompts_path.name})" + ) + for pi, prompt in enumerate(prompts): + short = (prompt[:60] + "…") if len(prompt) > 60 else prompt + print(f"[verify] prompt {pi + 1}/{len(prompts)}: {short!r}") + ptoks, tokenizer = _tokenize(prompt, model_name) + print(f"[verify] NPU greedy decode ({GATE_N_TOKENS} tokens)...") + npu_chosen, npu_topk = _generate_with_topk(npu, ptoks, GATE_N_TOKENS, GATE_K) + print(f"[verify] HF greedy decode ({GATE_N_TOKENS} tokens)...") + hf_chosen, hf_topk = _generate_with_topk(hf, ptoks, GATE_N_TOKENS, GATE_K) + + def _decorate(rec, test_seq): + """Inject decoded text into the record: + - the two chosen tokens at divergence (with rank context) + - the agreed prefix (the tokens both runners produced + identically before divergence) — empty string when + divergence_step == 0. + """ + rec.test_chosen_text_at_div = _decode_token_for_display( + tokenizer, rec.test_chosen_at_div + ) + rec.ref_chosen_text_at_div = _decode_token_for_display( + tokenizer, rec.ref_chosen_at_div + ) + if rec.divergence_step is not None and rec.divergence_step > 0: + prefix_ids = [int(t) for t in test_seq[: rec.divergence_step]] + raw = tokenizer.decode(prefix_ids) + raw = raw.replace("\\", "\\\\").replace("|", "\\|") + raw = raw.replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t") + rec.agreed_prefix_text = f'"{raw}"' + elif rec.divergence_step == 0: + rec.agreed_prefix_text = '""' + return rec + + rec = compute_topk_set_check( + test_chosen=npu_chosen, + test_topk=npu_topk, + ref_chosen=hf_chosen, + ref_topk=hf_topk, + k=GATE_K, + prompt_idx=pi, + prompt_text=short, + ) + report.add_topk(pair="npu_vs_hf", record=_decorate(rec, npu_chosen)) + + Path(args.report_dir).mkdir(parents=True, exist_ok=True) + stamp = datetime.now().strftime("%Y%m%d-%H%M%S") + json_path = Path(args.report_dir) / f"verify_topk_token_{stamp}.json" + md_path = Path(args.report_dir) / f"verify_topk_token_{stamp}.md" + report.dump_json(json_path) + report.dump_markdown(md_path) + print(f"\n[verify] Report: {md_path}") + print(f"[verify] JSON: {json_path}") + print(f"[verify] Summary: {report.summary()}") + if report.has_failure() and not args.no_strict: + print("[verify] FAIL — see report for details.", file=sys.stderr) + sys.exit(1) + print("[verify] PASS") + + +if __name__ == "__main__": + main()