[llama32_1b] 3-launch o_gemv_ffn for decode (+17% tok/s) #1631
Merged
erwei-xilinx merged 3 commits on May 30, 2026
Replaces the 8-launch o_gemv_ffn decode kernel with a 3-launch design
that stitches three primitives in one ELF and routes the post-attention
residual through a row-0 subview of a packed 2D arg so a single
NPU-computed value feeds two downstream consumers without a host copy
or an intermediate launch.
Stages:
1. matvec_2tile_add: res1 = wo @ attn_out + x_residual,
written into arg6[0]
2. matvec_swiglu_rms: swiglu = silu(gate @ rms_norm(arg6)) * up,
with gate/up rows interleaved into arg7
3. matvec_2tile_add: output = wdown @ swiglu + res1,
re-reading res1 from arg6[0]
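The three stages above can be sketched as a float32 NumPy reference. This is an illustrative host-side model, not the NPU kernel: the function names, the `eps` value, and the convention that even rows of `w_gateup` hold gate weights and odd rows hold up weights are all assumptions made for the sketch.

```python
import numpy as np


def rms_norm(x, weight, eps=1e-5):
    # Weighted RMSNorm; eps is an assumed value, not taken from the PR.
    return x / np.sqrt(np.mean(x * x) + eps) * weight


def silu(x):
    return x / (1.0 + np.exp(-x))


def o_gemv_ffn_reference(wo, attn_out, x_residual, ffn_norm_w, w_gateup, wdown):
    """Host-side model of the three stages (hypothetical helper)."""
    emb_dim = x_residual.shape[0]
    # Packed 2D arg: row 0 receives res1, row 1 holds the RMSNorm weights.
    arg6 = np.empty((2, emb_dim), dtype=np.float32)
    arg6[1] = ffn_norm_w

    # Stage 1: res1 = wo @ attn_out + x_residual, written into arg6[0].
    arg6[0] = wo @ attn_out + x_residual

    # Stage 2: one GEMV over interleaved gate/up rows, then SwiGLU.
    normed = rms_norm(arg6[0], arg6[1])
    gate_up = w_gateup @ normed          # length M = 2 * hidden
    gate, up = gate_up[0::2], gate_up[1::2]  # assumed row ordering
    swiglu = silu(gate) * up             # length M / 2

    # Stage 3: output = wdown @ swiglu + res1, re-reading res1 from arg6[0].
    return wdown @ swiglu + arg6[0]
```

A reference like this is mainly useful for comparing NPU output against a known-good result at small shapes before running the full model.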
End-to-end Llama-3.2-1B decode on NPU2: 91 ms/tok → 78 ms/tok
(10.93 → 12.78 tok/s, +17%). Generated text matches the baseline
(same response, same EOT stop).
Files
Two new generic primitives, each with __main__ + Makefile + lit tests:
- matrix_vector_multiplication/bf16_cascade/matvec_2tile_add.py
Two-tile-per-col matvec + residual add via intra-col npu_cascade
(matvec_herd north, add herd south).
- decode_ffn_swiglu/matvec_swiglu_rms.py
GEMV with weighted RMSNorm input (packed [2, K]) and fused SwiGLU
output (M/2 elements) over interleaved gate/up rows.
Llama integration:
- llama32_1b/multi_launch_builder/o_gemv_ffn_multi.py
Replaces the 8-launch builder. Three sub-launches stitched with
arg6[0] subview routing; 15-arg ABI kept for caller compatibility
(dead args become zero placeholders).
- llama32_1b/llama32_1b_decode.py + llama32_1b_inference.py
Decode call site uses the new arg6 packed RMS input + arg7
interleaved w_gateup. Preload builds the interleaved w_gateup once
per layer and frees the original wgate/wup arrays afterward (saves
~1 GB host RAM across 16 layers).
- llama32_1b/kernel_builder/external_kernels.py
Drops compile_mv_k8192 (8-launch artifact); adds compile_mv_bf16
for the new matvec_2tile_add micro-kernel.
- llama32_1b/kernel_builder/stitching.py
Adds _extract_channel_decls, #set affine alias handling in extract
+ rename, and arg_aliases parameter on _fix_launch_func_args so a
sub-launch's func arg can resolve to an arbitrary SSA value in the
combined func body (used here for the arg6[0] subview).
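The preload step described above (build the interleaved w_gateup once per layer, then free wgate/wup) might look roughly like the following. The helper name and the even/odd row ordering are assumptions; the shapes come from the lit-test dimensions (M=16384, K=2048) and the 16-layer count stated in this PR. The arithmetic at the end shows where the "~1 GB" figure plausibly comes from.

```python
import numpy as np


def interleave_gate_up(wgate, wup):
    """Interleave gate and up rows into one [2 * hidden, emb] matrix.

    Hypothetical helper: row 2*i is gate row i, row 2*i + 1 is up row i
    (the ordering is an assumption for this sketch).
    """
    assert wgate.shape == wup.shape
    hidden, emb = wgate.shape
    w_gateup = np.empty((2 * hidden, emb), dtype=wgate.dtype)
    w_gateup[0::2] = wgate
    w_gateup[1::2] = wup
    return w_gateup


# Host RAM held by the original wgate + wup arrays across all layers.
HIDDEN, EMB, LAYERS, BF16_BYTES = 8192, 2048, 16, 2
freed_bytes = 2 * HIDDEN * EMB * BF16_BYTES * LAYERS
freed_gib = freed_bytes / 2**30  # 64 MiB per layer over 16 layers
```

Once `interleave_gate_up` has produced the packed array for a layer, the two source arrays are no longer needed by decode, which is what makes freeing them safe.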
Compiler dependency
The new ELF emits a rank-reducing 2D→1D memref.subview at the top of
@o_gemv_ffn that is rejected by the current pinned mlir-aie at
aie.dma_bd lowering. This PR is gated on mlir-aie #3121 (generalizes
traceSubviewToBlockArgument to N-D rank-reducing subviews) being
merged. utils/clone-mlir-aie.sh HASH will be bumped in a follow-up
commit once a mlir-aie tag including #3121 exists.
Validation
- matvec_2tile_add lit tests (M=K=2048 and M=2048, K=8192):
correlation 0.999994 / 0.999991 on NPU2.
- matvec_swiglu_rms lit test (M=16384, K=2048): PASS on NPU2.
- Full llama32_1b inference (`make compile` + `make run`) on real
Llama-3.2-1B-Instruct: prompt "What is the capital of France?"
returns "The capital of France is Paris." at 12.78 tok/s.
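For readers unfamiliar with correlation-based checks, a minimal sketch follows. It assumes a Pearson correlation between the flattened NPU output and a host reference; the exact metric and threshold used by the lit tests are not stated here, so `threshold=0.999` and both function names are hypothetical.

```python
import numpy as np


def correlation(actual, expected):
    # Pearson correlation over flattened arrays, computed in float64
    # so that bf16/float32 inputs do not lose further precision.
    a = np.asarray(actual, dtype=np.float64).ravel()
    e = np.asarray(expected, dtype=np.float64).ravel()
    return float(np.corrcoef(a, e)[0, 1])


def passes(actual, expected, threshold=0.999):
    # threshold is an assumed value for illustration only.
    return correlation(actual, expected) >= threshold
```

A correlation check tolerates the small rounding differences expected from bf16 accumulation while still catching outputs that are structurally wrong, such as a swapped or missing residual.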
Documentation under llama32_1b/docs/ still references 8-launch design
details (mv_k8192.o, "8 launches"); a docs follow-up is appropriate
once this lands.
Pull request overview
This PR replaces the Llama 3.2 1B decode o_gemv_ffn path with a stitched 3-launch design using new BF16 matvec/add and fused RMSNorm+SwiGLU primitives.
Changes:
- Adds new generic NPU2/Peano primitives and lit coverage for `matvec_2tile_add` and `matvec_swiglu_rms`.
- Reworks `o_gemv_ffn_multi.py` to stitch the three stages and route the residual through `arg6[0]`.
- Updates decode preload/runtime plumbing to use packed RMS input and interleaved gate/up weights.
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| programming_examples/matrix_vector_multiplication/bf16_cascade/run_2tile_add_npu2_2048x8192_peano.lit | Adds large-K lit coverage for the 2-tile matvec+add primitive. |
| programming_examples/matrix_vector_multiplication/bf16_cascade/run_2tile_add_npu2_2048x2048_peano.lit | Adds base-shape lit coverage for the 2-tile matvec+add primitive. |
| programming_examples/matrix_vector_multiplication/bf16_cascade/mv_bf16.cc | Adds BF16 micro-kernels for zeroing, matvec accumulation, and residual add. |
| programming_examples/matrix_vector_multiplication/bf16_cascade/matvec_2tile_add.py | Adds AIR builder and standalone runner for matvec plus residual add. |
| programming_examples/matrix_vector_multiplication/bf16_cascade/Makefile | Adds build/run targets for the new 2-tile matvec+add primitive. |
| programming_examples/llama32_1b/multi_launch_builder/o_gemv_ffn_multi.py | Replaces the 8-launch decode FFN stitcher with the new 3-stage design. |
| programming_examples/llama32_1b/llama32_1b_inference.py | Prepares packed RMS buffers and interleaved gate/up weights for decode. |
| programming_examples/llama32_1b/llama32_1b_decode.py | Updates decode invocation to match the new 15-arg packed ABI. |
| programming_examples/llama32_1b/kernel_builder/stitching.py | Extends stitching helpers for channel decls, affine sets, and SSA aliases. |
| programming_examples/llama32_1b/kernel_builder/external_kernels.py | Replaces the old K8192 GEMV object build with mv_bf16.o. |
| programming_examples/decode_ffn_swiglu/run_npu2_peano.lit | Adds lit coverage for fused RMSNorm+SwiGLU GEMV. |
| programming_examples/decode_ffn_swiglu/matvec_swiglu_rms.py | Adds fused weighted RMSNorm, interleaved gate/up GEMV, and SwiGLU output primitive. |
| programming_examples/decode_ffn_swiglu/Makefile | Adds build/run targets for the fused decode FFN primitive. |
- cache.py: stage mv_bf16.o (used by o_gemv_ffn stages 1+3 via matvec_2tile_add); drop stale mv_k8192.o.
- matvec_swiglu_rms.py: deallocate l1_rms_in_data alongside the other per-segment L1 buffers.
- bf16_cascade/Makefile: forward TILE_M_2T / K_CHUNK_2T as -DDIM_M / -DDIM_K to mv_bf16.cc so the linked micro-kernel matches the MLIR-side tile shape under overrides.
Verified: llama32_1b decode "Paris" answer at 12.83 tok/s; swiglu_rms lit PASS; matvec_2tile_add lit correlation 0.999994 PASS.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
Replaces the 8-launch `o_gemv_ffn` decode kernel with a 3-launch design that stitches three primitives in one ELF and routes the post-attention residual through a row-0 subview of a packed 2D arg.
End-to-end on NPU2 with real Llama-3.2-1B-Instruct: 91 ms/tok → 78 ms/tok (10.93 → 12.78 tok/s, +17%).
Design
Three sub-launches inside one ELF, sharing core columns and sequenced:
`arg6` is a packed `memref<2 × emb_dim × bf16>`: row 0 = res1 (NPU-computed by stage 1), row 1 = ffn_norm_w (host pre-loaded once). Stage 1's D-output and stage 3's R-input both reference a row-0 subview of arg6, so the same NPU-computed res1 feeds both downstream consumers — no host copy, no intermediate launch.
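The aliasing idea can be illustrated on the host with NumPy, where indexing row 0 of a 2D array yields a 1D view over the same storage. This is only an analogy for the rank-reducing `memref.subview`: the variable names and the tiny `emb_dim` are made up for the sketch, and float32 stands in for bf16.

```python
import numpy as np

emb_dim = 8  # toy size for illustration; the real arg uses the model's emb_dim

# Packed 2D buffer: row 0 = res1 (written by stage 1), row 1 = ffn_norm_w.
arg6 = np.zeros((2, emb_dim), dtype=np.float32)
arg6[1] = 1.0  # stands in for the host-preloaded RMSNorm weights

# Rank-reducing view: shape (2, emb_dim) -> (emb_dim,), no copy is made.
res1_view = arg6[0]
shares = np.shares_memory(res1_view, arg6)

# "Stage 1" writes its result through the view ...
res1_view[:] = np.arange(emb_dim, dtype=np.float32)

# ... and "stage 3" re-reads the same storage through a fresh row-0 view.
stage3_input = arg6[0]
```

Because both views address the same bytes, the value produced by the first stage is visible to the later consumers without any explicit copy, which mirrors the role the row-0 subview plays in the stitched ELF.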
New generic primitives (each with main + Makefile + lit tests)
Llama integration
Compiler dependency — gating this PR
The new ELF emits a rank-reducing 2D→1D `memref.subview` at the top of `@o_gemv_ffn` that is rejected by the current pinned mlir-aie at `aie.dma_bd` lowering. This PR is gated on mlir-aie #3121 (generalizes `traceSubviewToBlockArgument` to N-D rank-reducing subviews) merging.
Once #3121 lands, this PR adds a follow-up commit bumping `utils/clone-mlir-aie.sh` HASH to a commit including the fix, and moves out of draft.
Validation
Known follow-ups (not in this PR)
Test plan