Skip to content

[llama32_1b] 3-launch o_gemv_ffn for decode (+17% tok/s)#1631

Merged
erwei-xilinx merged 3 commits into
Xilinx:mainfrom
erwei-xilinx:decode-3launch-o-gemv-ffn
May 30, 2026
Merged

[llama32_1b] 3-launch o_gemv_ffn for decode (+17% tok/s)#1631
erwei-xilinx merged 3 commits into
Xilinx:mainfrom
erwei-xilinx:decode-3launch-o-gemv-ffn

Conversation

@erwei-xilinx
Copy link
Copy Markdown
Collaborator

Summary

Replaces the 8-launch o_gemv_ffn decode kernel with a 3-launch design that stitches three primitives in one ELF and routes the post-attention residual through a row-0 subview of a packed 2D arg.

End-to-end on NPU2 with real Llama-3.2-1B-Instruct:

  • Per-token decode: 91 ms → 78 ms (−14%)
  • Tokens/sec: 10.93 → 12.78 (+17%)
  • Response matches baseline byte-for-byte ("What is the capital of France?" → "The capital of France is Paris.", same EOT stop).

Design

Three sub-launches inside one ELF, sharing core columns and sequenced:

Stage 1 (matvec_2tile_add):  res1 = wo @ attn_out + x_residual  →  arg6[0]
Stage 2 (matvec_swiglu_rms): swiglu = silu(gate @ rms_norm(arg6)) * up
                              with gate/up interleaved into arg7
Stage 3 (matvec_2tile_add):  output = wdown @ swiglu + res1     ←  arg6[0]

`arg6` is a packed `memref<2 × emb_dim × bf16>`: row 0 = res1 (NPU-computed by stage 1), row 1 = ffn_norm_w (host pre-loaded once). Stage 1's D-output and stage 3's R-input both reference a row-0 subview of arg6, so the same NPU-computed res1 feeds both downstream consumers — no host copy, no intermediate launch.

New generic primitives (each with main + Makefile + lit tests)

File What
`matrix_vector_multiplication/bf16_cascade/matvec_2tile_add.py` Two-tile-per-col matvec + residual add via intra-col `npu_cascade` (matvec_herd north, add herd south).
`decode_ffn_swiglu/matvec_swiglu_rms.py` GEMV with weighted-RMSNorm input (packed [2, K]) and fused SwiGLU output (M/2 elements) over interleaved gate/up rows.
`bf16_cascade/mv_bf16.cc` Micro-kernels for matvec_2tile_add: `matvec_vectorized_bf16`, `zero_vectorized_bf16`, `partial_plus_r_bf16`.

Llama integration

Compiler dependency — gating this PR

The new ELF emits a rank-reducing 2D→1D `memref.subview` at the top of `@o_gemv_ffn` that is rejected by the current pinned mlir-aie at `aie.dma_bd` lowering. This PR is gated on mlir-aie #3121 (generalizes `traceSubviewToBlockArgument` to N-D rank-reducing subviews) merging.

Once #3121 lands, this PR adds a follow-up commit bumping `utils/clone-mlir-aie.sh` HASH to a commit including the fix, and moves out of draft.

Validation

Test Result
`matvec_2tile_add` lit, M=K=2048 correlation 0.999994
`matvec_2tile_add` lit, M=2048, K=8192 correlation 0.999991
`matvec_swiglu_rms` lit, M=16384, K=2048 PASS
`llama32_1b` `make compile` + `make run` on real Llama-3.2-1B-Instruct "The capital of France is Paris." at 12.78 tok/s

Known follow-ups (not in this PR)

  1. Documentation under `llama32_1b/docs/` still references the 8-launch design (`mv_k8192.o`, "8 launches"). A docs refresh is a separate change.
  2. `make verify` is unchanged (it validates prefill, not the decode path being changed here).
  3. Compile time for the 3-launch `o_gemv_ffn` is ~3.5 min (Peano builds the inlined cascade kernels). The cached ELF is reused via the existing manifest, so this is only paid on a fresh build.

Test plan

  • `lit` invocation of the two `run_2tile_add_npu2_*` tests under `bf16_cascade/`
  • `lit` invocation of `run_npu2_peano.lit` under `decode_ffn_swiglu/`
  • `make compile` + `make run` in `programming_examples/llama32_1b/` produces the expected response
  • Hold draft until mlir-aie #3121 merges + bump commit added

@erwei-xilinx erwei-xilinx force-pushed the decode-3launch-o-gemv-ffn branch 2 times, most recently from b98ea43 to c04003e Compare May 28, 2026 23:43
erwei-xilinx and others added 2 commits May 28, 2026 16:43
Replaces the 8-launch o_gemv_ffn decode kernel with a 3-launch design
that stitches three primitives in one ELF and routes the post-attention
residual through a row-0 subview of a packed 2D arg so a single
NPU-computed value feeds two downstream consumers without a host copy
or an intermediate launch.

Stages:
  1. matvec_2tile_add: res1 = wo @ attn_out + x_residual,
                       written into arg6[0]
  2. matvec_swiglu_rms: swiglu = silu(gate @ rms_norm(arg6)) * up,
                        with gate/up rows interleaved into arg7
  3. matvec_2tile_add: output = wdown @ swiglu + res1,
                       re-reading res1 from arg6[0]

End-to-end Llama-3.2-1B decode on NPU2: 91 ms/tok → 78 ms/tok
(10.93 → 12.78 tok/s, +17%). Generated text matches the baseline
(same response, same EOT stop).

Files

  Two new generic primitives, each with __main__ + Makefile + lit tests:
  - matrix_vector_multiplication/bf16_cascade/matvec_2tile_add.py
    Two-tile-per-col matvec + residual add via intra-col npu_cascade
    (matvec_herd north, add herd south).
  - decode_ffn_swiglu/matvec_swiglu_rms.py
    GEMV with weighted RMSNorm input (packed [2, K]) and fused SwiGLU
    output (M/2 elements) over interleaved gate/up rows.

  Llama integration:
  - llama32_1b/multi_launch_builder/o_gemv_ffn_multi.py
    Replaces the 8-launch builder. Three sub-launches stitched with
    arg6[0] subview routing; 15-arg ABI kept for caller compatibility
    (dead args become zero placeholders).
  - llama32_1b/llama32_1b_decode.py + llama32_1b_inference.py
    Decode call site uses the new arg6 packed RMS input + arg7
    interleaved w_gateup. Preload builds the interleaved w_gateup once
    per layer and frees the original wgate/wup arrays afterward (saves
    ~1 GB host RAM across 16 layers).
  - llama32_1b/kernel_builder/external_kernels.py
    Drops compile_mv_k8192 (8-launch artifact); adds compile_mv_bf16
    for the new matvec_2tile_add micro-kernel.
  - llama32_1b/kernel_builder/stitching.py
    Adds _extract_channel_decls, #set affine alias handling in extract
    + rename, and arg_aliases parameter on _fix_launch_func_args so a
    sub-launch's func arg can resolve to an arbitrary SSA value in the
    combined func body (used here for the arg6[0] subview).

Compiler dependency

The new ELF emits a rank-reducing 2D→1D memref.subview at the top of
@o_gemv_ffn that is rejected by the current pinned mlir-aie at
aie.dma_bd lowering. This PR is gated on mlir-aie #3121 (generalizes
traceSubviewToBlockArgument to N-D rank-reducing subviews) being
merged. utils/clone-mlir-aie.sh HASH will be bumped in a follow-up
commit once a mlir-aie tag including #3121 exists.

Validation

  - matvec_2tile_add lit tests (M=K=2048 and M=2048, K=8192):
    correlation 0.999994 / 0.999991 on NPU2.
  - matvec_swiglu_rms lit test (M=16384, K=2048): PASS on NPU2.
  - Full llama32_1b inference (`make compile` + `make run`) on real
    Llama-3.2-1B-Instruct: prompt "What is the capital of France?"
    returns "The capital of France is Paris." at 12.78 tok/s.

Documentation under llama32_1b/docs/ still references 8-launch design
details (mv_k8192.o, "8 launches"); a docs follow-up is appropriate
once this lands.
@erwei-xilinx erwei-xilinx marked this pull request as ready for review May 29, 2026 23:31
@erwei-xilinx erwei-xilinx requested a review from jgmelber as a code owner May 29, 2026 23:31
Copilot AI review requested due to automatic review settings May 29, 2026 23:31
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR replaces the Llama 3.2 1B decode o_gemv_ffn path with a stitched 3-launch design using new BF16 matvec/add and fused RMSNorm+SwiGLU primitives.

Changes:

  • Adds new generic NPU2/Peano primitives and lit coverage for matvec_2tile_add and matvec_swiglu_rms.
  • Reworks o_gemv_ffn_multi.py to stitch the three stages and route the residual through arg6[0].
  • Updates decode preload/runtime plumbing to use packed RMS input and interleaved gate/up weights.

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
programming_examples/matrix_vector_multiplication/bf16_cascade/run_2tile_add_npu2_2048x8192_peano.lit Adds large-K lit coverage for the 2-tile matvec+add primitive.
programming_examples/matrix_vector_multiplication/bf16_cascade/run_2tile_add_npu2_2048x2048_peano.lit Adds base-shape lit coverage for the 2-tile matvec+add primitive.
programming_examples/matrix_vector_multiplication/bf16_cascade/mv_bf16.cc Adds BF16 micro-kernels for zeroing, matvec accumulation, and residual add.
programming_examples/matrix_vector_multiplication/bf16_cascade/matvec_2tile_add.py Adds AIR builder and standalone runner for matvec plus residual add.
programming_examples/matrix_vector_multiplication/bf16_cascade/Makefile Adds build/run targets for the new 2-tile matvec+add primitive.
programming_examples/llama32_1b/multi_launch_builder/o_gemv_ffn_multi.py Replaces the 8-launch decode FFN stitcher with the new 3-stage design.
programming_examples/llama32_1b/llama32_1b_inference.py Prepares packed RMS buffers and interleaved gate/up weights for decode.
programming_examples/llama32_1b/llama32_1b_decode.py Updates decode invocation to match the new 15-arg packed ABI.
programming_examples/llama32_1b/kernel_builder/stitching.py Extends stitching helpers for channel decls, affine sets, and SSA aliases.
programming_examples/llama32_1b/kernel_builder/external_kernels.py Replaces the old K8192 GEMV object build with mv_bf16.o.
programming_examples/decode_ffn_swiglu/run_npu2_peano.lit Adds lit coverage for fused RMSNorm+SwiGLU GEMV.
programming_examples/decode_ffn_swiglu/matvec_swiglu_rms.py Adds fused weighted RMSNorm, interleaved gate/up GEMV, and SwiGLU output primitive.
programming_examples/decode_ffn_swiglu/Makefile Adds build/run targets for the fused decode FFN primitive.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread programming_examples/llama32_1b/kernel_builder/external_kernels.py
Comment thread programming_examples/decode_ffn_swiglu/matvec_swiglu_rms.py
- cache.py: stage mv_bf16.o (used by o_gemv_ffn stages 1+3 via
  matvec_2tile_add); drop stale mv_k8192.o.
- matvec_swiglu_rms.py: deallocate l1_rms_in_data alongside the other
  per-segment L1 buffers.
- bf16_cascade/Makefile: forward TILE_M_2T / K_CHUNK_2T as -DDIM_M /
  -DDIM_K to mv_bf16.cc so the linked micro-kernel matches the
  MLIR-side tile shape under overrides.

Verified: llama32_1b decode "Paris" answer at 12.83 tok/s; swiglu_rms
lit PASS; matvec_2tile_add lit correlation 0.999994 PASS.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@erwei-xilinx erwei-xilinx added this pull request to the merge queue May 30, 2026
Merged via the queue into Xilinx:main with commit 462bf13 May 30, 2026
27 checks passed
@erwei-xilinx erwei-xilinx deleted the decode-3launch-o-gemv-ffn branch May 30, 2026 04:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants