From f91e6addfdf926360ae8e539afa47a7ed8fe847d Mon Sep 17 00:00:00 2001 From: erweiw Date: Mon, 6 Apr 2026 10:13:19 -0700 Subject: [PATCH 1/6] Add KV cache prefill flash attention example for AIE2P Add a new programming example that demonstrates fused flash attention with KV cache write-back on AIE2P NPU. This extends the existing kernel_fusion_based flash attention with K cache prefill capability, where RoPE'd K data is written back to DDR during attention computation. Key design features: - L1-to-L3 direct K write-back path bypassing memtile to avoid DMA channel congestion - Dedicated staging buffer to prevent DMA race conditions between K receive and write-back - Un-tiling DMA strides to convert 8x8 blocked L1 layout back to row-major for the K cache - Support for GQA (grouped query attention) with configurable head counts - Causal masking support - C++ test executable for ELF-based profiling workflow Co-Authored-By: Claude Opus 4.6 --- .../flash_attention/kv_cache_prefill/Makefile | 100 ++ .../kv_cache_prefill/attn_npu2.cc | 642 +++++++ .../kv_cache_prefill/attn_npu2.py | 1541 +++++++++++++++++ .../run_npu2_makefile_peano_elf.lit | 10 + .../kv_cache_prefill/run_test.sh | 20 + .../kv_cache_prefill/test_elf_npu2.cpp | 252 +++ 6 files changed, 2565 insertions(+) create mode 100644 programming_examples/flash_attention/kv_cache_prefill/Makefile create mode 100644 programming_examples/flash_attention/kv_cache_prefill/attn_npu2.cc create mode 100644 programming_examples/flash_attention/kv_cache_prefill/attn_npu2.py create mode 100644 programming_examples/flash_attention/kv_cache_prefill/run_npu2_makefile_peano_elf.lit create mode 100755 programming_examples/flash_attention/kv_cache_prefill/run_test.sh create mode 100644 programming_examples/flash_attention/kv_cache_prefill/test_elf_npu2.cpp diff --git a/programming_examples/flash_attention/kv_cache_prefill/Makefile b/programming_examples/flash_attention/kv_cache_prefill/Makefile new file mode 100644 index 000000000..f624bf40d --- /dev/null +++ b/programming_examples/flash_attention/kv_cache_prefill/Makefile @@ -0,0 +1,100 @@ +# Copyright (C) 2025, Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT +srcdir := $(shell dirname $(realpath $(firstword $(MAKEFILE_LIST)))) +kerneldir := $(srcdir)/../kernel_fusion_based + +# Attention parameters +LK ?= 512 +LKP ?= 64 +LQ ?= 512 +LQP ?= 256 +DK ?= 64 +DV ?= 64 +NUM_HEADS ?= 2 +NUM_KV_HEADS ?= $(NUM_HEADS) + +# Derived: kernel tile size = LQP / num_q_tiles (4) +NUM_Q_TILES ?= 4 +LQP_TILE := $(shell echo $$(($(LQP) / $(NUM_Q_TILES)))) + + +# Determine build dir based on whether PEANO_INSTALL_DIR is set +ifdef PEANO_INSTALL_DIR + BUILD_DIR := build_peano +else + BUILD_DIR := build_chess +endif + +AIEOPT_DIR = $(shell realpath $(dir $(shell which aie-opt))/..) +WARNING_FLAGS = -Wno-parentheses -Wno-attributes -Wno-macro-redefined -Wno-empty-body +PEANOWRAP2P_FLAGS = -Os -std=c++20 --target=aie2p-none-unknown-elf ${WARNING_FLAGS} -DNDEBUG -I ${AIEOPT_DIR}/include + +all: run + +print: + ${powershell} python3 ${srcdir}/attn_npu2.py -p --lk $(LK) --lkp $(LKP) --lq $(LQ) --lqp $(LQP) --dk $(DK) --dv $(DV) --num-heads $(NUM_HEADS) --num-kv-heads $(NUM_KV_HEADS) + +run: compile-kernel + mkdir -p $(BUILD_DIR) + PEANO_INSTALL_DIR=$(PEANO_INSTALL_DIR) cd $(BUILD_DIR) && ${powershell} python3 ${srcdir}/attn_npu2.py --lk $(LK) --lkp $(LKP) --lq $(LQ) --lqp $(LQP) --dk $(DK) --dv $(DV) --num-heads $(NUM_HEADS) --num-kv-heads $(NUM_KV_HEADS) $(EXTRA_PY_FLAGS) + +# Profile ELF: compile elf and run with C++ test executable for elf format +# Usage: make profile [LK=...] [LQ=...] etc. +profile: compile-kernel build-test-exe + PEANO_INSTALL_DIR=$(PEANO_INSTALL_DIR) cd $(BUILD_DIR) && ${powershell} python3 ${srcdir}/attn_npu2.py \ + --lk $(LK) --lkp $(LKP) --lq $(LQ) --lqp $(LQP) --dk $(DK) --dv $(DV) --num-heads $(NUM_HEADS) --num-kv-heads $(NUM_KV_HEADS) \ + --compile-mode compile-only + PEANO_INSTALL_DIR=$(PEANO_INSTALL_DIR) cd $(BUILD_DIR) && ./test_elf_npu2.exe -e air.elf -k "main:attention_bf16" \ + --lq $(LQ) --lk $(LK) --dk $(DK) --dv $(DV) --num-heads $(NUM_HEADS) + +build-test-exe: + @GPP=$$( \ + for bin in /usr/bin/g++-*; do \ + ver=$$(echo $$bin | grep -oE '[0-9]+$$'); \ + if [ "$$ver" -ge 13 ] 2>/dev/null; then \ + echo "$$ver $$bin"; \ + fi; \ + done | sort -nr | head -n1 | awk '{print $$2}' \ + ); \ + if [ -z "$$GPP" ]; then \ + echo "Error: No g++ version >= 13 found in /usr/bin."; \ + exit 1; \ + fi; \ + if [ -z "$$XILINX_XRT" ]; then \ + echo "Error: XILINX_XRT environment variable not set. Please make sure to have sourced xrt/setup.sh."; \ + exit 1; \ + fi; \ + if [ -z "$(AIEOPT_DIR)" ]; then \ + echo "Error: AIEOPT_DIR environment variable not set. Please make sure to have sourced utils/env_setup.sh."; \ + exit 1; \ + fi; \ + echo "Using compiler: $$GPP"; \ + mkdir -p $(BUILD_DIR); \ + cd $(BUILD_DIR) && $$GPP ${srcdir}/test_elf_npu2.cpp -o test_elf_npu2.exe -std=c++23 -Wall \ + -I$$XILINX_XRT/include -L$$XILINX_XRT/lib \ + -I$(AIEOPT_DIR)/runtime_lib/x86_64/test_lib/include \ + -L$(AIEOPT_DIR)/runtime_lib/x86_64/test_lib/lib \ + -luuid -lxrt_coreutil -lrt -lstdc++ -ltest_utils + +# Compile local kernel (with RoPE support) +compile-kernel: + mkdir -p $(BUILD_DIR) + @if [ -n "$(PEANO_INSTALL_DIR)" ]; then \ + echo "Detected PEANO_INSTALL_DIR from environment: $(PEANO_INSTALL_DIR)"; \ + if [ -x "$(PEANO_INSTALL_DIR)/bin/clang++" ]; then \ + echo "Using clang++ from PEANO_INSTALL_DIR=$(PEANO_INSTALL_DIR)"; \ + $(PEANO_INSTALL_DIR)/bin/clang++ ${PEANOWRAP2P_FLAGS} -DBIT_WIDTH=8 -I${kerneldir} -c ${srcdir}/attn_npu2.cc -o $(BUILD_DIR)/attn_npu2.o -Dlqp=$(LQP_TILE) -Dlkp=$(LKP) -Ddk=$(LKP) -Ddk_full=$(DK) -Ddv=$(LKP) -Ddv_full=$(DV) -DAIE_API_EMULATE_BFLOAT16_MMUL_WITH_BFP16 -DROUND_CONV_EVEN $(EXTRA_KERNEL_FLAGS); \ + else \ + echo "Error: invalid PEANO_INSTALL_DIR, clang++ not found."; \ + exit 1; \ + fi; \ + elif command -v xchesscc_wrapper >/dev/null 2>&1; then \ + echo "Using xchesscc_wrapper from PATH"; \ + cd $(BUILD_DIR) && ${powershell} xchesscc_wrapper aie2p -I${kerneldir} -c ${srcdir}/attn_npu2.cc -o attn_npu2.o -Dlqp=$(LQP_TILE) -Dlkp=$(LKP) -Ddk=$(LKP) -Ddk_full=$(DK) -Ddv=$(LKP) -Ddv_full=$(DV); \ + else \ + echo "Error: Neither PEANO_INSTALL_DIR nor xchesscc_wrapper found."; \ + exit 1; \ + fi + +clean: + rm -rf $(BUILD_DIR) __pycache__ diff --git a/programming_examples/flash_attention/kv_cache_prefill/attn_npu2.cc b/programming_examples/flash_attention/kv_cache_prefill/attn_npu2.cc new file mode 100644 index 000000000..228d43e03 --- /dev/null +++ b/programming_examples/flash_attention/kv_cache_prefill/attn_npu2.cc @@ -0,0 +1,642 @@ +//===- attn_npu2.cc ------------------------------------------*- C++ -*-===// +// +// SPDX-License-Identifier: MIT +// Copyright (C) 2025, Advanced Micro Devices, Inc. +// +// Flash attention kernel for KV cache prefill. +// Forked from kernel_fusion_based/attn_npu2.cc. +// RoPE is applied on the host before sending Q/K to the NPU. +// TODO: Add on-chip rope_sincos that computes sin/cos directly on AIE2P. +// +//===----------------------------------------------------------------------===// + +#define NOCPP + +#include +#include +#include +#include + +#define REL_WRITE 0 +#define REL_READ 1 + +#include + +#include "zero.cc" + +// Default values if not provided by Makefile +#ifndef lqp +#define lqp 32 +#endif + +#ifndef lkp +#define lkp 96 +#endif + +#ifndef dk +#define dk 64 +#endif + +#ifndef dv +#define dv 64 +#endif + +#ifndef dv_full +#define dv_full dv +#endif + +// Column-major B matmul with compile-time transpose control. +// transpose_b: true = apply aie::transpose before mac (K DMA: inner [n_in, +// k_in]) +// false = load B as-is, hardware mul_8x8_8x8T transposes (V DMA: +// inner [k_in, n_in]) +// A and C are always column-major tiled. +template +static inline void matmul_vectorized_2x2_mmul(const T_in *__restrict pA, + const T_in *__restrict pB, + T_out *__restrict pC) { + + using MMUL = aie::mmul; + + event0(); + + for (unsigned z = 0; z < rowA; z += 2) + chess_prepare_for_pipelining chess_loop_range(2, ) { + T_out *__restrict pC1 = pC + (z)*MMUL::size_C; + T_out *__restrict pC2 = pC + ((z + 1)) * MMUL::size_C; + + for (unsigned j = 0; j < colB; j += 2) +#ifdef OPT_PERF_ENABLED + chess_flatten_loop +#endif + { + const T_in *__restrict pA1 = pA + (z)*MMUL::size_A; + const T_in *__restrict pA2 = pA + ((z + 1)) * MMUL::size_A; + const T_in *__restrict pB1 = pB + (j)*colA * MMUL::size_B; + const T_in *__restrict pB2 = pB + (j + 1) * colA * MMUL::size_B; + + aie::vector acc_C00 = + aie::load_v(pC1); + aie::vector acc_C01 = + aie::load_v(pC1 + MMUL::size_C * rowA); + aie::vector acc_C10 = + aie::load_v(pC2); + aie::vector acc_C11 = + aie::load_v(pC2 + MMUL::size_C * rowA); + + MMUL C00(acc_C00); + MMUL C01(acc_C01); + MMUL C10(acc_C10); + MMUL C11(acc_C11); + + for (unsigned i = 0; i < colA; ++i) +#ifdef OPT_PERF_ENABLED + chess_flatten_loop +#endif + { + aie::vector A0 = + aie::load_v(pA1); + pA1 += rowA * MMUL::size_A; + aie::vector A1 = + aie::load_v(pA2); + pA2 += rowA * MMUL::size_A; + + aie::vector B0, B1; + if constexpr (transpose_b) { + // K DMA k-major block layout: block (n=j, k=i) at i*colB+j. + // Sub-tile elements are [n_in, k_in], transpose to [k_in, + // n_in]. + const T_in *__restrict pBk0 = + pB + (i * colB + j) * MMUL::size_B; + const T_in *__restrict pBk1 = + pB + (i * colB + (j + 1)) * MMUL::size_B; + B0 = aie::transpose(aie::load_v(pBk0), t, s); + B1 = aie::transpose(aie::load_v(pBk1), t, s); + } else { + // V DMA inner layout is [k_in, n_in] — already correct for + // hardware mul_8x8_8x8T, no software transpose needed. + B0 = aie::load_v(pB1); + B1 = aie::load_v(pB2); + } + pB1 += MMUL::size_B; + pB2 += MMUL::size_B; + + C00.mac(A0, B0); + C01.mac(A0, B1); + C10.mac(A1, B0); + C11.mac(A1, B1); + } + + aie::store_v(pC1, C00.template to_vector()); + pC1 += MMUL::size_C * rowA; + aie::store_v(pC1, C01.template to_vector()); + pC1 += MMUL::size_C * rowA; + aie::store_v(pC2, C10.template to_vector()); + pC2 += MMUL::size_C * rowA; + aie::store_v(pC2, C11.template to_vector()); + pC2 += MMUL::size_C * rowA; + } + } + + event1(); +} + +// bf16 MatMul kernel with bf16 outputs. +// transpose_b: controls whether B blocks are software-transposed before mac. +template +static inline void +matmul_vectorized_8x8x8_bf16_bf16(const bfloat16 *__restrict pA, + const bfloat16 *__restrict pB, + bfloat16 *__restrict pC) { + constexpr int r = 8; + constexpr int s = 8; + constexpr int t = 8; + static_assert(m % (2 * r) == 0); // 'm' dimension + static_assert(k % s == 0); // 'k' dimension + static_assert(n % (2 * t) == 0); // 'n' dimension + + return matmul_vectorized_2x2_mmul(pA, pB, pC); +} + +// Combined scale: log2e / sqrt(dk_full). Applies 1/sqrt(dk) inside softmax +// with accfloat precision, avoiding bf16 truncation of Q. +// dk is the tile dimension (= lkp), dk_full is the full key dimension. +// When dk_full == dk (default), sqrt(64) = 8.0 — no change. +#include + +#ifndef dk_full +#define dk_full dk +#endif + +constexpr double constexpr_sqrt_dk = (dk_full == 64) ? 8.0 + : (dk_full == 128) ? 11.313708498984761 + : (dk_full == 256) ? 16.0 + : (dk_full == 512) ? 22.627416997969522 + : 8.0; +static_assert(dk_full == 64 || dk_full == 128 || dk_full == 256 || + dk_full == 512, + "Unsupported dk_full value: update constexpr_sqrt_dk"); + +#define log2e (1.44269504089 / constexpr_sqrt_dk) + +__attribute__((always_inline)) v8bfloat16 getExpBf16(v8bfloat16 x) { + + constexpr int VecLen = 8; + + // Calculate the e^(x) function as 2^(log2e * x) + aie::vector input_bf16 = x; + aie::accum exp_in; + aie::vector exp_val; + aie::vector log2e_vec = + aie::broadcast(log2e); + + exp_in = aie::mul(input_bf16, log2e_vec); + exp_val = aie::exp2(exp_in.to_vector()); + return exp_val; +} + +extern "C" { + +// Set rounding mode at the start of every extern C function. +// Setting conv_even rounding in every softmax function; without this, +// softmax intermediates use the system default rounding mode, +// causing ~44% errors at val_range=4 due to rounding noise +// amplified by softmax's peaked distribution. +#ifdef ROUND_CONV_EVEN +#define SET_ROUNDING() ::aie::set_rounding(::aie::rounding_mode::conv_even) +#else +#define SET_ROUNDING() /* no-op */ +#endif + +// Copy tile_size_q×dk elements from src to dst (single-pass vector copy) +void copy_tile(bfloat16 *src, bfloat16 *dst) { + SET_ROUNDING(); + constexpr int VecLen = 32; + constexpr int num_elems = lqp * dk; + bfloat16 *__restrict ps = src; + bfloat16 *__restrict pd = dst; + for (unsigned j = 0; j < num_elems / VecLen; j++) + chess_prepare_for_pipelining chess_loop_range(8, ) { + aie::vector v = aie::load_v(ps); + aie::store_v(pd, v); + ps += VecLen; + pd += VecLen; + } +} + +void matmul_a_b_bf16(bfloat16 *a_in, bfloat16 *b_in, bfloat16 *out) { + SET_ROUNDING(); + // Buffer shapes: + // A: [lqp, dk] = [32, 64] + // B: [lkp, dk] = [96, 64] (K row-major, aie::transpose per block) + // Out: [lqp, lkp] = [32, 96] + matmul_vectorized_8x8x8_bf16_bf16(a_in, b_in, out); +} + +void matmul_g_b_bf16(bfloat16 *g_in, bfloat16 *b_in, bfloat16 *out) { + SET_ROUNDING(); + // Buffer shapes: + // G: [lqp, lkp] = [32, 96] + // B: [lkp, dv] = [96, 64] + // Out: [lqp, dv] = [32, 64] + // G@V: V DMA inner layout is [k_in, n_in], so NO software transpose needed. + // The hardware mul_8x8_8x8T already transposes B internally. + matmul_vectorized_8x8x8_bf16_bf16( + g_in, b_in, out); +} + +void zero_fill_gp_bf16(bfloat16 *c_out) { + SET_ROUNDING(); + // Buffer shape: [lqp, dv] = [32, 64] + zero_vectorized(c_out); +} + +void zero_fill_sp_bf16(bfloat16 *c_out) { + SET_ROUNDING(); + // Buffer shape: [lqp, 1] = [32, 1] + zero_vectorized(c_out); +} + +void zero_fill_g_bf16(bfloat16 *c_out) { + SET_ROUNDING(); + // Buffer shape: [lqp, lkp] = [32, 96] + zero_vectorized(c_out); +} + +void neg_inf_fill_up_bf16(bfloat16 *c_out) { + SET_ROUNDING(); + // Buffer shape: [lqp, 1] = [32, 1] + neg_inf_vectorized(c_out); +} + +void max_g_bf16(bfloat16 *in, bfloat16 *out) { + SET_ROUNDING(); + constexpr int VecLen = 32; + constexpr int BlockSize = 64; + constexpr int ColsPerBlock = 8; + constexpr int RowsPerBlock = 8; + constexpr int col_blocks = lkp / ColsPerBlock; + constexpr int row_blocks = lqp / RowsPerBlock; + constexpr int block_stride = lqp * ColsPerBlock; + + uint16_t lowest_u16 = (uint16_t)0xff7f; + bfloat16 lowest_val = *(bfloat16 *)&lowest_u16; + + bfloat16 *__restrict pOut = out; + for (int rb = 0; rb < row_blocks; rb++) { + for (int half = 0; half < 2; half++) { + aie::vector max_vec = + aie::broadcast(lowest_val); + int base = rb * BlockSize + half * VecLen; + for (int cb = 0; cb < col_blocks; cb++) + chess_prepare_for_pipelining chess_loop_range(8, ) { + aie::vector v = + aie::load_v(in + base + cb * block_stride); + max_vec = aie::max(max_vec, v); + } + aie::vector r0 = max_vec.extract<8>(0); + aie::vector r1 = max_vec.extract<8>(1); + aie::vector r2 = max_vec.extract<8>(2); + aie::vector r3 = max_vec.extract<8>(3); + pOut[half * 4 + 0] = aie::reduce_max(r0); + pOut[half * 4 + 1] = aie::reduce_max(r1); + pOut[half * 4 + 2] = aie::reduce_max(r2); + pOut[half * 4 + 3] = aie::reduce_max(r3); + } + pOut += RowsPerBlock; + } +} + +void maximum_up_u_bf16(bfloat16 *up, bfloat16 *u) { + SET_ROUNDING(); + constexpr int VecLen = 32; + constexpr int num_elems = lqp; + bfloat16 *__restrict pu = u; + for (int i = 0; i < num_elems; i += VecLen) { + aie::vector up_temp = aie::load_v(up + i); + aie::vector u_temp = aie::load_v(pu); + u_temp = aie::max(up_temp, u_temp); + aie::store_v(pu, u_temp); + pu += VecLen; + } +} + +void exp_g_minus_u(bfloat16 *u, bfloat16 *g) { + SET_ROUNDING(); + constexpr int VecLen = 32; + constexpr int BlockSize = 64; + constexpr int ColsPerBlock = 8; + constexpr int RowsPerBlock = 8; + constexpr int col_blocks = lkp / ColsPerBlock; + constexpr int row_blocks = lqp / RowsPerBlock; + constexpr int block_stride = lqp * ColsPerBlock; + + uint16_t lowest_u16 = (uint16_t)0xff7f; + bfloat16 lowest_val = *(bfloat16 *)&lowest_u16; + aie::vector log2e_vec16 = + aie::broadcast((bfloat16)log2e); + aie::vector lowest_vec = + aie::broadcast(lowest_val); + + for (int rb = 0; rb < row_blocks; rb++) { + for (int half = 0; half < 2; half++) { + int row_start = rb * RowsPerBlock + half * 4; + aie::vector u0 = aie::broadcast(u[row_start]); + aie::vector u1 = + aie::broadcast(u[row_start + 1]); + aie::vector u2 = + aie::broadcast(u[row_start + 2]); + aie::vector u3 = + aie::broadcast(u[row_start + 3]); + aie::vector u_vec; + u_vec.insert(0, u0); + u_vec.insert(1, u1); + u_vec.insert(2, u2); + u_vec.insert(3, u3); + + int base = rb * BlockSize + half * VecLen; + for (int cb = 0; cb < col_blocks; cb++) + chess_prepare_for_pipelining chess_loop_range(8, ) { + int off = base + cb * block_stride; + aie::vector v = aie::load_v(g + off); + v = aie::sub(v, u_vec); + v = aie::max(v, lowest_vec); + aie::vector lo = v.extract<16>(0); + aie::vector hi = v.extract<16>(1); + lo = + aie::exp2(aie::mul(lo, log2e_vec16).to_vector()); + hi = + aie::exp2(aie::mul(hi, log2e_vec16).to_vector()); + v.insert(0, lo); + v.insert(1, hi); + aie::store_v(g + off, v); + } + } + } +} + +void exp_up_minus_u(bfloat16 *up, bfloat16 *u, bfloat16 *r) { + SET_ROUNDING(); + constexpr int VecLen = 16; + constexpr int num_elems = lqp; + uint16_t lowest_u16 = (uint16_t)0xff7f; + bfloat16 lowest_val = *(bfloat16 *)&lowest_u16; + aie::vector lowest_vec = + aie::broadcast(lowest_val); + bfloat16 *__restrict pr = r; + bfloat16 *__restrict pu = u; + bfloat16 *__restrict pup = up; + aie::vector log2e_vec = + aie::broadcast((bfloat16)log2e); + for (int i = 0; i < num_elems; i += VecLen) { + aie::vector uTemp = aie::load_v(pu); + aie::vector upTemp = aie::load_v(pup); + aie::vector diff = aie::sub(upTemp, uTemp); + diff = aie::max(diff, lowest_vec); + aie::vector exp_val = + aie::exp2(aie::mul(diff, log2e_vec).to_vector()); + aie::store_v(pr, exp_val); + pr += VecLen; + pu += VecLen; + pup += VecLen; + } +} + +void mul_r_gp(bfloat16 *r, bfloat16 *gp) { + SET_ROUNDING(); + constexpr int VecLen = 32; + constexpr int BlockSize = 64; + constexpr int ColsPerBlock = 8; + constexpr int RowsPerBlock = 8; + constexpr int col_blocks = dv / ColsPerBlock; + constexpr int row_blocks = lqp / RowsPerBlock; + constexpr int block_stride = lqp * ColsPerBlock; + + for (int rb = 0; rb < row_blocks; rb++) { + for (int half = 0; half < 2; half++) { + int row_start = rb * RowsPerBlock + half * 4; + aie::vector r0 = aie::broadcast(r[row_start]); + aie::vector r1 = + aie::broadcast(r[row_start + 1]); + aie::vector r2 = + aie::broadcast(r[row_start + 2]); + aie::vector r3 = + aie::broadcast(r[row_start + 3]); + aie::vector r_vec; + r_vec.insert(0, r0); + r_vec.insert(1, r1); + r_vec.insert(2, r2); + r_vec.insert(3, r3); + + int base = rb * BlockSize + half * VecLen; + for (int cb = 0; cb < col_blocks; cb++) + chess_prepare_for_pipelining chess_loop_range(8, ) { + int off = base + cb * block_stride; + aie::vector v = aie::load_v(gp + off); + aie::accum acc = aie::mul(v, r_vec); + aie::store_v(gp + off, acc.to_vector()); + } + } + } +} + +void sum_g(bfloat16 *g, bfloat16 *s) { + SET_ROUNDING(); + constexpr int VecLen = 32; + constexpr int BlockSize = 64; + constexpr int ColsPerBlock = 8; + constexpr int RowsPerBlock = 8; + constexpr int col_blocks = lkp / ColsPerBlock; + constexpr int row_blocks = lqp / RowsPerBlock; + constexpr int block_stride = lqp * ColsPerBlock; + + bfloat16 *__restrict ps = s; + for (int rb = 0; rb < row_blocks; rb++) { + for (int half = 0; half < 2; half++) { + aie::accum sum_acc = aie::zeros(); + int base = rb * BlockSize + half * VecLen; + for (int cb = 0; cb < col_blocks; cb++) + chess_prepare_for_pipelining chess_loop_range(8, ) { + aie::vector v = + aie::load_v(g + base + cb * block_stride); + sum_acc = aie::add(sum_acc, v); + } + aie::vector sum_v = sum_acc.to_vector(); + aie::vector r0 = sum_v.extract<8>(0); + aie::vector r1 = sum_v.extract<8>(1); + aie::vector r2 = sum_v.extract<8>(2); + aie::vector r3 = sum_v.extract<8>(3); + ps[half * 4 + 0] = (bfloat16)aie::reduce_add(r0); + ps[half * 4 + 1] = (bfloat16)aie::reduce_add(r1); + ps[half * 4 + 2] = (bfloat16)aie::reduce_add(r2); + ps[half * 4 + 3] = (bfloat16)aie::reduce_add(r3); + } + ps += RowsPerBlock; + } +} + +void accum_sp_r_s(bfloat16 *sp, bfloat16 *r, bfloat16 *s) { + SET_ROUNDING(); + constexpr int VecLen = 32; + constexpr int num_elems = lqp; + bfloat16 *__restrict pr = r; + bfloat16 *__restrict ps = s; + bfloat16 *__restrict psp = sp; + for (int i = 0; i < num_elems; i += VecLen) { + aie::vector rTemp = aie::load_v(pr); + aie::vector spTemp = aie::load_v(psp); + aie::accum accTemp = aie::mul(rTemp, spTemp); + accTemp = aie::add(accTemp, aie::load_v(ps)); + aie::vector sTemp = to_v32bfloat16(accTemp); + aie::store_v(ps, sTemp); + pr += VecLen; + ps += VecLen; + psp += VecLen; + } +} + +void vector_copy_32elems(const int offset, const bfloat16 *__restrict inputs, + bfloat16 *__restrict outputs) { + constexpr int VecLen = 32; + constexpr int num_elems = lqp; + const bfloat16 *__restrict pIn = inputs; + bfloat16 *__restrict pOut = outputs + offset; + for (unsigned j = 0; j < num_elems / VecLen; j++) { + aie::vector vec = aie::load_v(pIn); + pIn += VecLen; + aie::store_v(pOut, vec); + pOut += VecLen; + } +} + +void div_gp_sp(bfloat16 *sp, bfloat16 *gp) { + SET_ROUNDING(); + constexpr int VecLen = 32; + constexpr int BlockSize = 64; + constexpr int ColsPerBlock = 8; + constexpr int RowsPerBlock = 8; + constexpr int col_blocks = dv / ColsPerBlock; + constexpr int row_blocks = lqp / RowsPerBlock; + constexpr int block_stride = lqp * ColsPerBlock; + + for (int rb = 0; rb < row_blocks; rb++) { + for (int half = 0; half < 2; half++) { + int row_start = rb * RowsPerBlock + half * 4; + aie::vector sp0 = aie::broadcast(sp[row_start]); + aie::vector sp1 = + aie::broadcast(sp[row_start + 1]); + aie::vector sp2 = + aie::broadcast(sp[row_start + 2]); + aie::vector sp3 = + aie::broadcast(sp[row_start + 3]); + aie::vector sp_vec; + sp_vec.insert(0, sp0); + sp_vec.insert(1, sp1); + sp_vec.insert(2, sp2); + sp_vec.insert(3, sp3); + aie::vector sp_inv = aie::inv(sp_vec); + + int base = rb * BlockSize + half * VecLen; + for (int cb = 0; cb < col_blocks; cb++) + chess_prepare_for_pipelining chess_loop_range(8, ) { + int off = base + cb * block_stride; + aie::vector v = aie::load_v(gp + off); + aie::accum acc = aie::mul(v, sp_inv); + aie::store_v(gp + off, acc.to_vector()); + } + } + } +} + +// Fused softmax: delegates to existing optimized VecLen=32 kernels. +// On return: up=new_max, sp=sum(exp(G)), r=rescale_factor, G=exp(G-max). +void fused_softmax(bfloat16 *g, bfloat16 *up, bfloat16 *sp, bfloat16 *r) { + SET_ROUNDING(); + max_g_bf16(g, r); + maximum_up_u_bf16(up, r); + exp_g_minus_u(r, g); + exp_up_minus_u(up, r, sp); + vector_copy_32elems(0, r, up); + vector_copy_32elems(0, sp, r); + sum_g(g, sp); +} + +void add_gp_g(bfloat16 *gp, bfloat16 *g) { + SET_ROUNDING(); + constexpr int VecLen = 32; + constexpr int num_elems = lqp * dv; + bfloat16 *__restrict gp_ptr = gp; + bfloat16 *__restrict g_ptr = g; + for (unsigned j = 0; j < num_elems / VecLen; j++) { + aie::vector gp_vec = aie::load_v(gp_ptr); + aie::vector g_vec = aie::load_v(g_ptr); + aie::accum acc(gp_vec); + acc = aie::add(acc, g_vec); + aie::store_v(g_ptr, acc.to_vector()); + gp_ptr += VecLen; + g_ptr += VecLen; + } +} + +// Apply causal mask to QK scores in-place. +void apply_causal_mask(bfloat16 *g, int32_t q_block_idx, int32_t kv_block_idx) { + SET_ROUNDING(); + uint16_t neg_inf_u16 = (uint16_t)0xff80; + bfloat16 neg_inf_val = *(bfloat16 *)&neg_inf_u16; + + if (kv_block_idx > q_block_idx) { + constexpr int VecLen = 32; + aie::vector neg_inf_vec = + aie::broadcast(neg_inf_val); + bfloat16 *p = g; + for (int i = 0; i < lqp * lkp; i += VecLen) { + aie::store_v(p, neg_inf_vec); + p += VecLen; + } + return; + } + + if (kv_block_idx < q_block_idx) { + return; + } + + constexpr int BlkDim = 8; + aie::vector mask_vec = + aie::broadcast(neg_inf_val); + + for (int row = 0; row < lqp; row++) { + int mask_start = row + 1; + int row_blk = row / BlkDim; + int row_in = row % BlkDim; + + for (int col_blk = 0; col_blk < lkp / BlkDim; col_blk++) { + int col_start = col_blk * BlkDim; + int off = col_blk * (lqp * BlkDim) + row_blk * (BlkDim * BlkDim) + + row_in * BlkDim; + + aie::vector orig = aie::load_v(g + off); + + if (mask_start >= lkp) { + aie::store_v(g + off, orig); + } else if (col_start >= mask_start) { + aie::store_v(g + off, mask_vec); + } else if (col_start + BlkDim > mask_start) { + uint32_t sel_bits = 0; + for (int c = 0; c < BlkDim; c++) { + if (col_start + c >= mask_start) { + sel_bits |= (1u << c); + } + } + aie::mask sel(sel_bits); + aie::store_v(g + off, aie::select(orig, mask_vec, sel)); + } else { + aie::store_v(g + off, orig); + } + } + } +} + +} // extern "C" diff --git a/programming_examples/flash_attention/kv_cache_prefill/attn_npu2.py b/programming_examples/flash_attention/kv_cache_prefill/attn_npu2.py new file mode 100644 index 000000000..14c9d1fb1 --- /dev/null +++ b/programming_examples/flash_attention/kv_cache_prefill/attn_npu2.py @@ -0,0 +1,1541 @@ +# Copyright (C) 2025, Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT +"""Fused Flash Attention + KV Cache Write-Back (Single Launch). + +Single-launch design that fuses flash attention and KV cache write-back +into one AIE program. When RoPE is enabled, Q and K are pre-RoPE'd by +the host before being sent to the NPU — the NPU performs attention on +already-rotated data. The K data is written back to L3 K cache via DMA +from tx=0 tiles. + +TODO: Replace host-side RoPE with on-chip rope_sincos kernel that +computes sin/cos directly on AIE2P without needing a LUT. + +DMA channel strategy (2 S2MM + 2 MM2S per compute tile): + S2MM 0: QK channel (Q and K via L2 relay) + S2MM 1: V (per-stage via memtile) + MM2S 0: Cascade or output (ty=0) + MM2S 1: K cache write-back (tx=0 only) + +Channel layout: + QKIn_s/QK2L1_s: per-stage memtile relay with horizontal broadcast + VIn_s/V2L1_s: per-stage memtile relay with horizontal broadcast + cascade_gp/cascade_up/cascade_sp: 2D cascade channels (per-segment) + Gp2L2/GpOut: output from ty=0 tiles + KWB: K cache write-back (tx=0 tiles send K directly to L3 via tile DMA) + +Note: V cache is NOT written by the NPU. The host should copy input V +directly to the V cache buffer if needed. +""" + +import argparse +import os +from math import sqrt + +import numpy as np + +import air +from air.ir import * +from air.dialects.affine import apply as affine_apply +from air.dialects.air import * +from air.dialects.air import channel +from air.dialects.arith import ConstantOp +from air.dialects.memref import AllocOp, CollapseShapeOp, DeallocOp, load, store +from air.dialects.func import FuncOp, CallOp +from air.dialects.scf import for_ as scf_range, yield_ +from air.dialects import scf, affine, arith + + +@module_builder +def build_module( + lk=512, + lkp=64, + lq=512, + lqp=256, + dk=64, + dv=64, + num_q_tiles=4, + num_cascade_stages=4, + num_heads=2, + num_kv_heads=None, + causal=False, + enable_k_writeback=True, +): + """Build flash attention + KV cache module (RoPE applied on host). + + Args: + lk: Total K/V sequence length (default: 512) + lkp: K/V chunk size per tile (default: 64) + lq: Total Q sequence length (default: 512) + lqp: Q chunk size per launch iteration (default: 256) + dk: Key dimension (default: 64) + dv: Value dimension (default: 64) + num_q_tiles: Number of tiles to partition Q chunk into (default: 4) + num_cascade_stages: Number of cascade pipeline stages (default: 4) + num_heads: Number of attention heads (default: 2) + num_kv_heads: Number of key/value heads for grouped-query attention + (GQA). If None, defaults to num_heads (standard MHA). + causal: Whether to enable causal (autoregressive) masking. + """ + # Validate + assert lq % lqp == 0, f"lq ({lq}) must be divisible by lqp ({lqp})" + assert ( + lqp % num_q_tiles == 0 + ), f"lqp ({lqp}) must be divisible by num_q_tiles ({num_q_tiles})" + assert lk % lkp == 0, f"lk ({lk}) must be divisible by lkp ({lkp})" + assert lk % (lkp * num_cascade_stages) == 0, ( + f"lk ({lk}) must be divisible by lkp * num_cascade_stages " + f"({lkp * num_cascade_stages})" + ) + dk_tile = lkp + assert dk % dk_tile == 0, f"dk ({dk}) must be divisible by dk_tile/lkp ({dk_tile})" + dk_chunks = dk // dk_tile + dv_tile = lkp + assert dv % dv_tile == 0, f"dv ({dv}) must be divisible by dv_tile/lkp ({dv_tile})" + dv_chunks = dv // dv_tile + if causal: + assert lq == lk, f"Causal masking requires lq == lk, got lq={lq}, lk={lk}" + assert lqp // num_q_tiles == lkp, ( + f"Causal masking requires tile_size_q == lkp, got " + f"tile_size_q={lqp // num_q_tiles}, lkp={lkp}" + ) + + # Multi-head / GQA parameters + if num_kv_heads is None: + num_kv_heads = num_heads + assert num_kv_heads > 0, f"num_kv_heads must be positive, got {num_kv_heads}" + assert num_heads % num_kv_heads == 0, ( + f"num_heads ({num_heads}) must be divisible by " + f"num_kv_heads ({num_kv_heads})" + ) + gqa_group_size = num_heads // num_kv_heads + + num_heads_per_unroll = 2 + assert num_heads % num_heads_per_unroll == 0, ( + f"num_heads ({num_heads}) must be divisible by " + f"num_heads_per_unroll ({num_heads_per_unroll})" + ) + num_head_groups = num_heads // num_heads_per_unroll + + bf16 = Type.parse("bf16") + i32 = IntegerType.get_signless(32) + index_type = IndexType.get() + + M = 8 # mmul_m = mmul_k = mmul_n + + # Derived parameters + num_lq_iters = lq // lqp + tile_size_q = lqp // num_q_tiles + num_chunks = lk // lkp + chunks_per_stage = num_chunks // num_cascade_stages + lk_per_stage = lkp * chunks_per_stage + + NQ = num_q_tiles + NS = num_cascade_stages + + # Memory spaces + l1_space = IntegerAttr.get(i32, 2) + l2_space = IntegerAttr.get(i32, 1) + + # L1 MemRefTypes (Q and K use dk_tile, not full dk) + q_l1_t = MemRefType.get([tile_size_q, dk_tile], bf16, memory_space=l1_space) + k_l1_t = MemRefType.get([lkp, dk_tile], bf16, memory_space=l1_space) + v_l1_t = MemRefType.get([lkp, dv_tile], bf16, memory_space=l1_space) + g_l1_2d = MemRefType.get([tile_size_q, lkp], bf16, memory_space=l1_space) + g_l1_1d = MemRefType.get([tile_size_q * lkp], bf16, memory_space=l1_space) + gp_l1_t = MemRefType.get([tile_size_q, dv_tile], bf16, memory_space=l1_space) + up_l1_t = MemRefType.get([tile_size_q, 1], bf16, memory_space=l1_space) + + # L2 MemRefTypes (QK relay uses dk_tile) + qk_l2_t = MemRefType.get([lkp, dk_tile], bf16, memory_space=l2_space) + v_l2_t = MemRefType.get([lkp, dv_tile], bf16, memory_space=l2_space) + gp_l2_t = MemRefType.get([lqp, dv_tile], bf16, memory_space=l2_space) + + # L3 MemRefTypes (3D with head dimension) + q_l3_t = MemRefType.get([num_heads, lq, dk], bf16) + k_l3_t = MemRefType.get([num_kv_heads, lk, dk], bf16) + # V and output L3 use transposed layout for contiguous dv_tile access: + # [heads * dv_chunks, seq, dv_tile] instead of [heads, seq, dv] + v_l3_t = MemRefType.get([num_kv_heads * dv_chunks, lk, dv_tile], bf16) + gp_l3_t = MemRefType.get([num_heads * dv_chunks, lq, dv_tile], bf16) + + # KV cache L3 types + k_cache_l3_t = MemRefType.get([num_kv_heads, lk, dk], bf16) + v_cache_l3_t = MemRefType.get([num_kv_heads * dv_chunks, lk, dv_tile], bf16) + + # External function declarations + def external_func(name, inputs, outputs=None, link_with=None, visibility="private"): + if outputs is None: + outputs = [] + func_type = FunctionType.get(inputs, outputs) + func = FuncOp(name=name, type=func_type, visibility=visibility) + func.attributes["llvm.emit_c_interface"] = UnitAttr.get() + if link_with: + func.attributes["link_with"] = StringAttr.get(link_with) + return func + + external_func("zero_fill_g_bf16", [g_l1_1d], link_with="attn_npu2.o") + external_func("zero_fill_gp_bf16", [gp_l1_t], link_with="attn_npu2.o") + external_func("zero_fill_sp_bf16", [up_l1_t], link_with="attn_npu2.o") + external_func("neg_inf_fill_up_bf16", [up_l1_t], link_with="attn_npu2.o") + external_func( + "matmul_a_b_bf16", + [q_l1_t, k_l1_t, g_l1_1d], + link_with="attn_npu2.o", + ) + external_func( + "matmul_g_b_bf16", + [g_l1_1d, v_l1_t, gp_l1_t], + link_with="attn_npu2.o", + ) + external_func( + "fused_softmax", + [g_l1_1d, up_l1_t, up_l1_t, up_l1_t], + link_with="attn_npu2.o", + ) + external_func("maximum_up_u_bf16", [up_l1_t, up_l1_t], link_with="attn_npu2.o") + external_func( + "exp_up_minus_u", + [up_l1_t, up_l1_t, up_l1_t], + link_with="attn_npu2.o", + ) + external_func("mul_r_gp", [up_l1_t, gp_l1_t], link_with="attn_npu2.o") + external_func( + "accum_sp_r_s", + [up_l1_t, up_l1_t, up_l1_t], + link_with="attn_npu2.o", + ) + external_func( + "vector_copy_32elems", [i32, up_l1_t, up_l1_t], link_with="attn_npu2.o" + ) + external_func("copy_tile", [k_l1_t, q_l1_t], link_with="attn_npu2.o") + external_func("div_gp_sp", [up_l1_t, gp_l1_t], link_with="attn_npu2.o") + external_func("add_gp_g", [gp_l1_t, gp_l1_t], link_with="attn_npu2.o") + if causal: + external_func("apply_causal_mask", [g_l1_2d, i32, i32], link_with="attn_npu2.o") + + # ---------------------------------------------------------------- + # Channel declarations + # ---------------------------------------------------------------- + + # QK: per-stage through memtile (3D with head dimension) + for s in range(NS): + Channel( + f"QK2L1_{s}", + size=[num_heads_per_unroll, 1, 1], + broadcast_shape=[num_heads_per_unroll, 1, NQ], + ) + Channel(f"QKIn_{s}", size=[num_heads_per_unroll]) + + # V: per-stage through memtile (3D with head dimension) + for s in range(NS): + Channel( + f"V2L1_{s}", + size=[num_heads_per_unroll, 1, 1], + broadcast_shape=[num_heads_per_unroll, 1, NQ], + ) + Channel(f"VIn_{s}", size=[num_heads_per_unroll]) + + # Cascade: 2D per-segment (shared within each segment instance) + channel("cascade_gp", size=[NQ, NS - 1], channel_type="cascade") + channel("cascade_up", size=[NQ, NS - 1], channel_type="cascade") + channel("cascade_sp", size=[NQ, NS - 1], channel_type="cascade") + + # Output: L1-to-L2 gather, then L2-to-L3 + Channel("Gp2L2", size=[NQ, 1]) + Channel("GpOut", size=[num_heads_per_unroll]) + + # K cache write-back: tile-level L1→L3 (tx=0 tiles write K to shim) + if enable_k_writeback: + Channel("KWB", size=[num_heads_per_unroll, NS, 1]) + + # ---------------------------------------------------------------- + # Main function: fused RoPE + attention + KV cache + # ---------------------------------------------------------------- + func_args = [q_l3_t, k_l3_t, v_l3_t, gp_l3_t, k_cache_l3_t, v_cache_l3_t] + + @FuncOp.from_py_func(*func_args) + def attention_bf16(*func_params): + q_in, k_in, v_in, gp_out, k_cache, v_cache = func_params + c1 = ConstantOp(index_type, 1) + c_lq_iters = ConstantOp(index_type, num_lq_iters) + c_num_head_groups = ConstantOp(index_type, num_head_groups) + + if dv_chunks > 1: + c_dv_chunks = ConstantOp(index_type, dv_chunks) + launch_sizes = [c_lq_iters, c_num_head_groups, c_dv_chunks] + else: + launch_sizes = [c_lq_iters, c_num_head_groups] + + launch_operands = [q_in, k_in, v_in, gp_out, k_cache, v_cache] + + @launch( + operands=launch_operands, + sizes=launch_sizes, + ) + def launch_body(*launch_args): + if dv_chunks > 1: + lx, ly, lz, lsx, lsy, lsz, q, k, v, gp, kcache, vcache = launch_args + else: + lx, ly, lsx, lsy, q, k, v, gp, kcache, vcache = launch_args + lz = ConstantOp(index_type, 0) + + # Compute Q offset from launch iteration index + affine_map_q_launch = AffineMap.get( + 0, + 1, + [ + AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(lqp * dk), + ) + ], + ) + q_launch_off = affine_apply(affine_map_q_launch, [lx]) + + # Output launch offset (transposed layout uses dv_tile, not dv) + affine_map_out_launch = AffineMap.get( + 0, + 1, + [ + AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(lqp * dv_tile), + ) + ], + ) + out_launch_off = affine_apply(affine_map_out_launch, [lx]) + + # Compute head base from head group index (ly) + affine_map_head_base = AffineMap.get( + 0, + 1, + [ + AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(num_heads_per_unroll), + ) + ], + ) + head_base = affine_apply(affine_map_head_base, [ly]) + + # Offset maps for one head's worth of Q/K/V/output data + affine_map_head_q = AffineMap.get( + 0, + 1, + [ + AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(lq * dk), + ) + ], + ) + affine_map_head_k = AffineMap.get( + 0, + 1, + [ + AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(lk * dk), + ) + ], + ) + affine_map_head_v_dv = AffineMap.get( + 0, + 2, + [ + AffineExpr.get_mul( + AffineExpr.get_add( + AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(dv_chunks), + ), + AffineSymbolExpr.get(1), + ), + AffineConstantExpr.get(lk * dv_tile), + ) + ], + ) + affine_map_head_out_dv = AffineMap.get( + 0, + 2, + [ + AffineExpr.get_mul( + AffineExpr.get_add( + AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(dv_chunks), + ), + AffineSymbolExpr.get(1), + ), + AffineConstantExpr.get(lq * dv_tile), + ) + ], + ) + + # s0 + s1 + affine_map_add = AffineMap.get( + 0, + 2, + [ + AffineExpr.get_add( + AffineSymbolExpr.get(0), + AffineSymbolExpr.get(1), + ) + ], + ) + + # head_1 = head_base + 1 + affine_map_plus1 = AffineMap.get( + 0, + 1, + [ + AffineExpr.get_add( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(1), + ) + ], + ) + + # GQA head map + if gqa_group_size > 1: + affine_map_kv_head = AffineMap.get( + 0, + 1, + [ + AffineExpr.get_floor_div( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(gqa_group_size), + ) + ], + ) + + for head_local in range(num_heads_per_unroll): + if head_local == 0: + head_idx = head_base + else: + head_idx = affine_apply(affine_map_plus1, [head_base]) + + if gqa_group_size == 1: + kv_head_idx = head_idx + else: + kv_head_idx = affine_apply( + affine_map_kv_head, + [head_idx], + ) + + head_q_off = affine_apply(affine_map_head_q, [head_idx]) + head_k_off = affine_apply(affine_map_head_k, [kv_head_idx]) + head_v_off = affine_apply(affine_map_head_v_dv, [kv_head_idx, lz]) + head_out_off = affine_apply(affine_map_head_out_dv, [head_idx, lz]) + + head_offset_idx = ConstantOp(index_type, head_local) + + q_combined = affine_apply(affine_map_add, [head_q_off, q_launch_off]) + out_combined = affine_apply( + affine_map_add, [head_out_off, out_launch_off] + ) + + # ---------------------------------------------------------- + # Q puts: bulk Q data to QKIn (pre-RoPE'd when enabled) + # ---------------------------------------------------------- + for stage in range(NS): + ChannelPut( + f"QKIn_{stage}", + q, + indices=[head_offset_idx], + offsets=[0, q_combined], + sizes=[NQ, dk_chunks, tile_size_q, dk_tile], + strides=[tile_size_q * dk, dk_tile, dk, 1], + ) + + # ---------------------------------------------------------- + # K puts: bulk K data to QKIn (pre-RoPE'd when enabled) + # ---------------------------------------------------------- + k_stage_off_val = stage * lk_per_stage * dk + k_combined = affine_apply( + affine_map_add, + [head_k_off, ConstantOp(index_type, k_stage_off_val)], + ) + ChannelPut( + f"QKIn_{stage}", + k, + indices=[head_offset_idx], + offsets=[0, k_combined], + sizes=[chunks_per_stage, dk_chunks, lkp, dk_tile], + strides=[lkp * dk, dk_tile, dk, 1], + ) + + # ---------------------------------------------------------- + # V puts: bulk V data to VIn + # ---------------------------------------------------------- + v_stage_off_val = stage * lk_per_stage * dv_tile + v_combined = affine_apply( + affine_map_add, + [head_v_off, ConstantOp(index_type, v_stage_off_val)], + ) + ChannelPut( + f"VIn_{stage}", + v, + indices=[head_offset_idx], + offsets=[0, 0, v_combined], + sizes=[chunks_per_stage, lkp, dv_tile], + strides=[lkp * dv_tile, dv_tile, 1], + ) + + # ---------------------------------------------------------- + # K cache get (L1→L3 from tx=0 tiles per stage) + # ---------------------------------------------------------- + if enable_k_writeback: + # GQA: skip duplicate writes for heads sharing same KV head + is_first_in_gqa_group = ( + gqa_group_size == 1 or head_local % gqa_group_size == 0 + ) + if is_first_in_gqa_group: + for stage in range(NS): + k_stage_off_val_wb = stage * lk_per_stage * dk + k_combined_wb = affine_apply( + affine_map_add, + [ + head_k_off, + ConstantOp(index_type, k_stage_off_val_wb), + ], + ) + ChannelGet( + "KWB", + kcache, + indices=[ + head_offset_idx, + ConstantOp(index_type, stage), + ConstantOp(index_type, 0), + ], + offsets=[0, k_combined_wb], + sizes=[ + chunks_per_stage, + dk_chunks, + lkp, + dk_tile, + ], + strides=[lkp * dk, dk_tile, dk, 1], + ) + + # ---------------------------------------------------------- + # Output get (after K cache, matches segment ordering) + # ---------------------------------------------------------- + ChannelGet( + "GpOut", + gp, + indices=[head_offset_idx], + offsets=[out_combined], + sizes=[lqp * dv_tile], + strides=[1], + ) + + # ---------------------------------------------------------- + # Segment: unrolled over heads + # ---------------------------------------------------------- + c_num_heads_unroll = ConstantOp(index_type, num_heads_per_unroll) + c1_seg = ConstantOp(index_type, 1) + + @segment( + name="attn_seg", + operands=[lx], + sizes=[c_num_heads_unroll, c1_seg], + ) + def segment_body(seg_x, seg_y, seg_sx, seg_sy, seg_lx): + # L2 allocations for QK and V (per-stage) and output + qk_l2_bufs = [AllocOp(qk_l2_t, [], []) for _ in range(NS)] + v_l2_bufs = [AllocOp(v_l2_t, [], []) for _ in range(NS)] + gp_l2 = AllocOp(gp_l2_t, [], []) + + # L1 allocations passed to herd + q_saved_bufs = [AllocOp(q_l1_t, [], []) for _ in range(dk_chunks)] + # K buffer: reused across chunks in the chunk loop. + # Only one K buffer needed since K is processed one chunk + # at a time (receive K, matmul, write-back). + k_saved_bufs = [AllocOp(k_l1_t, [], [])] + if enable_k_writeback: + kwb_l1 = AllocOp(k_l1_t, [], []) + v_l1 = AllocOp(v_l1_t, [], []) + g_l1 = AllocOp(g_l1_2d, [], []) + gp_l1 = AllocOp(gp_l1_t, [], []) + up_l1 = AllocOp(up_l1_t, [], []) + sp_l1 = AllocOp(up_l1_t, [], []) + if causal: + ctr_size = 4 if dv_chunks > 1 else 3 + ctr_t = MemRefType.get([ctr_size], i32, memory_space=l1_space) + causal_ctr = AllocOp(ctr_t, [], []) + + c_nq = ConstantOp(index_type, NQ) + c_ns = ConstantOp(index_type, NS) + c0_seg = ConstantOp(index_type, 0) + c_chunks_s = ConstantOp(index_type, chunks_per_stage) + + # ---------------------------------------------------------- + # Per-stage relay: Q tiles, then per-chunk K + K-WB + V. + # ---------------------------------------------------------- + c_q_relay = ConstantOp(index_type, NQ * dk_chunks) + c_dk_chunks = ConstantOp(index_type, dk_chunks) + + # Tiling transform for QK2L1 relay + qk_relay_sizes = [dk_tile // M, lkp // M, M, M] + qk_relay_strides = [M, dk_tile * M, dk_tile, 1] + + for stage in range(NS): + # === Q phase: Q tiles === + for qt_iter in scf_range(0, c_q_relay, 1): + ChannelGet( + f"QKIn_{stage}", + qk_l2_bufs[stage].result, + indices=[seg_x], + ) + ChannelPut( + f"QK2L1_{stage}", + qk_l2_bufs[stage].result, + indices=[seg_x, c0_seg, c0_seg], + offsets=[0, 0, 0, 0], + sizes=qk_relay_sizes, + strides=qk_relay_strides, + ) + yield_([]) + + # === Per-chunk: K data, V === + for chunk_iter in scf_range(0, c_chunks_s, 1): + # K data relay: QKIn → QK2L1 + for k_iter in scf_range(0, c_dk_chunks, 1): + ChannelGet( + f"QKIn_{stage}", + qk_l2_bufs[stage].result, + indices=[seg_x], + ) + ChannelPut( + f"QK2L1_{stage}", + qk_l2_bufs[stage].result, + indices=[seg_x, c0_seg, c0_seg], + offsets=[0, 0, 0, 0], + sizes=qk_relay_sizes, + strides=qk_relay_strides, + ) + yield_([]) + + # V relay: VIn → V2L1 + ChannelGet( + f"VIn_{stage}", + v_l2_bufs[stage].result, + indices=[seg_x], + ) + ChannelPut( + f"V2L1_{stage}", + v_l2_bufs[stage].result, + indices=[seg_x, c0_seg, c0_seg], + offsets=[0, 0, 0, 0], + sizes=[dv_tile // M, lkp // M, M, M], + strides=[M, dv_tile * M, dv_tile, 1], + ) + yield_([]) + + # Output gather from ty=0 tiles (after K write-back) + affine_map_col = AffineMap.get( + 0, + 1, + [ + AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(tile_size_q), + ) + ], + ) + par_out = scf.ForallOp(lower_bounds=[0], upper_bounds=[NQ], steps=[1]) + with InsertionPoint(par_out.body): + apply_off = affine_apply( + affine_map_col, + [par_out.induction_variables[0]], + ) + ChannelGet( + "Gp2L2", + gp_l2.result, + indices=[par_out.induction_variables[0], 0], + offsets=[apply_off, 0], + sizes=[tile_size_q, dv_tile], + strides=[dv_tile, 1], + ) + scf.InParallelOp() + + # Output: L2-to-L3 + ChannelPut("GpOut", gp_l2.result, indices=[seg_x]) + + # ---------------------------------------------------------- + # Herd: [NQ, NS] + # ---------------------------------------------------------- + herd_operands = ( + q_saved_bufs + + k_saved_bufs + + [ + v_l1, + g_l1, + gp_l1, + up_l1, + sp_l1, + seg_x, + ] + ) + if enable_k_writeback: + herd_operands.append(kwb_l1) + if causal: + herd_operands.append(causal_ctr) + + @herd( + name="herd_0", + sizes=[c_nq, c_ns], + operands=herd_operands, + link_with="attn_npu2.o", + ) + def herd_body(tx, ty, hsx, hsy, *all_args): + # Unpack: dk_chunks Q bufs, 1 K buf (reused per chunk), + # v, g, gp, up, sp, seg_x, [kwb_buf], [causal_ctr] + q_bufs = list(all_args[:dk_chunks]) + qk_tmp = all_args[dk_chunks] # single K/QK temp buffer + base = dk_chunks + 1 + v = all_args[base] + g = all_args[base + 1] + gp = all_args[base + 2] + up_buf = all_args[base + 3] + sp_buf = all_args[base + 4] + h_seg_x = all_args[base + 5] + next_idx = base + 6 + kwb_buf = None + if enable_k_writeback: + kwb_buf = all_args[next_idx] + next_idx += 1 + counter_buf = all_args[next_idx] if causal else None + + # Precompute affine sets for per-stage dispatch + s0 = AffineSymbolExpr.get(0) + s1 = AffineSymbolExpr.get(1) + c_ns_m1 = AffineConstantExpr.get(NS - 1) + stage_sets = [] + for s in range(NS): + cs = AffineConstantExpr.get(s) + stage_sets.append( + IntegerSet.get( + 0, + 2, + [s0, s1 - cs], + [False, True], + ) + ) + + # === INIT PHASE === + CallOp([], "zero_fill_gp_bf16", [gp]) + CallOp([], "zero_fill_sp_bf16", [sp_buf]) + CallOp([], "neg_inf_fill_up_bf16", [up_buf]) + + # === CAUSAL COUNTER INIT === + if causal: + c0_ctr = ConstantOp(index_type, 0) + c1_ctr = ConstantOp(index_type, 1) + c2_ctr = ConstantOp(index_type, 2) + c3_ctr = ConstantOp(index_type, 3) if dv_chunks > 1 else None + boot_flag = load(counter_buf, [c1_ctr]) + is_first = arith.CmpIOp( + arith.CmpIPredicate.eq, + boot_flag, + ConstantOp(i32, 0), + ) + if_first = scf.IfOp(is_first) + with InsertionPoint(if_first.then_block): + store(ConstantOp(i32, 0), counter_buf, [c0_ctr]) + store(ConstantOp(i32, 1), counter_buf, [c1_ctr]) + store(ConstantOp(i32, 0), counter_buf, [c2_ctr]) + if dv_chunks > 1: + store(ConstantOp(i32, 0), counter_buf, [c3_ctr]) + scf.YieldOp([]) + + # === Q SELECTIVE CAPTURE === + # Phase 1: Receive all Q data tiles, copy to saved bufs. + for qt in range(NQ): + for dk_c in range(dk_chunks): + # Receive Q tile → qk_tmp + for s in range(NS): + if_qk_q = affine.AffineIfOp( + stage_sets[s], + cond_operands=[tx, ty], + ) + with InsertionPoint(if_qk_q.then_block): + ChannelGet( + f"QK2L1_{s}", + qk_tmp, + indices=[h_seg_x, ty, tx], + ) + affine.AffineYieldOp([]) + # Copy Q to saved buffer if tx==qt + cmp = arith.CmpIOp( + arith.CmpIPredicate.eq, + arith.IndexCastOp(i32, tx), + arith.ConstantOp(i32, qt), + ) + if_cap = scf.IfOp(cmp) + with InsertionPoint(if_cap.then_block): + CallOp([], "copy_tile", [qk_tmp, q_bufs[dk_c]]) + scf.YieldOp([]) + + # === CHUNK LOOP: K receive, matmul, V, softmax === + # K tiles are received one chunk at a time (not bulk). + c_chunks_h = ConstantOp(index_type, chunks_per_stage) + for chunk_iter in scf_range(0, c_chunks_h, 1): + # 1. Zero fill G + g1d = CollapseShapeOp(g_l1_1d, g, [[0, 1]]) + CallOp([], "zero_fill_g_bf16", [g1d]) + + # 2. Receive K tile(s) + matmul + for dk_c in range(dk_chunks): + # Receive K → qk_tmp + for s in range(NS): + if_qk_k = affine.AffineIfOp( + stage_sets[s], + cond_operands=[tx, ty], + ) + with InsertionPoint(if_qk_k.then_block): + ChannelGet( + f"QK2L1_{s}", + qk_tmp, + indices=[h_seg_x, ty, tx], + ) + affine.AffineYieldOp([]) + + # Matmul Q @ K → G + CallOp( + [], + "matmul_a_b_bf16", + [q_bufs[dk_c], qk_tmp, g1d], + ) + + # K write-back (tx=0 only, L1→L3) + # Copy K data to a separate staging buffer first + # to avoid DMA race with the next chunk's K receive + # into qk_tmp. Data in qk_tmp is in tiled [M,M] + # format from QK2L1 relay. Un-tile to row-major + # [lkp, dk_tile] for the K cache. + if enable_k_writeback: + cmp_tx0 = arith.CmpIOp( + arith.CmpIPredicate.eq, + arith.IndexCastOp(i32, tx), + arith.ConstantOp(i32, 0), + ) + if_tx0 = scf.IfOp(cmp_tx0) + with InsertionPoint(if_tx0.then_block): + CallOp([], "copy_tile", [qk_tmp, kwb_buf]) + for s in range(NS): + if_kwb = affine.AffineIfOp( + stage_sets[s], + cond_operands=[tx, ty], + ) + with InsertionPoint(if_kwb.then_block): + ChannelPut( + "KWB", + kwb_buf, + indices=[ + h_seg_x, + ConstantOp(index_type, s), + tx, + ], + offsets=[0, 0, 0, 0], + sizes=[ + lkp // M, + M, + dk_tile // M, + M, + ], + strides=[ + M * M, + M, + lkp * M, + 1, + ], + ) + affine.AffineYieldOp([]) + scf.YieldOp([]) + + # 3. V get via affine.if per stage + for s in range(NS): + if_v = affine.AffineIfOp( + stage_sets[s], + cond_operands=[tx, ty], + ) + with InsertionPoint(if_v.then_block): + ChannelGet( + f"V2L1_{s}", + v, + indices=[h_seg_x, ty, tx], + ) + affine.AffineYieldOp([]) + + # 4b. Apply causal mask + if causal: + c_cps_i32 = ConstantOp(i32, chunks_per_stage) + ty_i32 = arith.IndexCastOp(i32, ty).result + chunk_i32 = arith.IndexCastOp( + i32, + chunk_iter, + ).result + kv_base = arith.MulIOp(ty_i32, c_cps_i32) + kv_block = arith.AddIOp( + kv_base.result, + chunk_i32, + ) + q_base = load(counter_buf, [c0_ctr]) + tx_i32 = arith.IndexCastOp(i32, tx).result + q_block = arith.AddIOp(q_base, tx_i32) + CallOp( + [], + "apply_causal_mask", + [g, q_block.result, kv_block.result], + ) + + # 5. Softmax + accumulate + s_tmp = AllocOp(up_l1_t, [], []) + r_tmp = AllocOp(up_l1_t, [], []) + CallOp( + [], + "fused_softmax", + [g1d, up_buf, s_tmp.result, r_tmp.result], + ) + CallOp([], "mul_r_gp", [r_tmp.result, gp]) + CallOp([], "matmul_g_b_bf16", [g1d, v, gp]) + + c0_i32 = ConstantOp(i32, 0) + CallOp( + [], + "accum_sp_r_s", + [sp_buf, r_tmp.result, s_tmp.result], + ) + CallOp( + [], + "vector_copy_32elems", + [c0_i32, s_tmp.result, sp_buf], + ) + DeallocOp(s_tmp) + DeallocOp(r_tmp) + yield_([]) + + # === CASCADE MERGE === + set_first_stage = IntegerSet.get( + 0, 2, [s0, s1 - c_ns_m1], [False, True] + ) + set_middle_stage = IntegerSet.get( + 0, + 2, + [ + AffineExpr.get_add(s1, AffineConstantExpr.get(-1)), + AffineExpr.get_add( + AffineConstantExpr.get(NS - 2), + AffineExpr.get_mul(s1, AffineConstantExpr.get(-1)), + ), + s0, + AffineExpr.get_add( + AffineConstantExpr.get(NQ - 1), + AffineExpr.get_mul(s0, AffineConstantExpr.get(-1)), + ), + ], + [False, False, False, False], + ) + c1_h = ConstantOp(index_type, 1) + + # Last stage (ty == NS-1): send cascade down + if_last = affine.AffineIfOp( + set_first_stage, + cond_operands=[tx, ty], + has_else=True, + ) + with InsertionPoint(if_last.then_block): + subi_l = arith.SubIOp(ty, c1_h) + ChannelPut("cascade_gp", gp, indices=[tx, subi_l]) + ChannelPut("cascade_up", up_buf, indices=[tx, subi_l]) + ChannelPut("cascade_sp", sp_buf, indices=[tx, subi_l]) + affine.AffineYieldOp([]) + + with InsertionPoint(if_last.else_block): + if_mid = affine.AffineIfOp( + set_middle_stage, + cond_operands=[tx, ty], + has_else=True, + ) + with InsertionPoint(if_mid.then_block): + gp_c = AllocOp(gp_l1_t, [], []) + up_c = AllocOp(up_l1_t, [], []) + sp_c = AllocOp(up_l1_t, [], []) + ChannelGet("cascade_gp", gp_c.result, indices=[tx, ty]) + ChannelGet("cascade_up", up_c.result, indices=[tx, ty]) + ChannelGet("cascade_sp", sp_c.result, indices=[tx, ty]) + up_s = AllocOp(up_l1_t, [], []) + c0m = ConstantOp(i32, 0) + CallOp( + [], "vector_copy_32elems", [c0m, up_buf, up_s.result] + ) + CallOp([], "maximum_up_u_bf16", [up_c.result, up_buf]) + rc = AllocOp(up_l1_t, [], []) + CallOp( + [], "exp_up_minus_u", [up_c.result, up_buf, rc.result] + ) + rl = AllocOp(up_l1_t, [], []) + CallOp( + [], "exp_up_minus_u", [up_s.result, up_buf, rl.result] + ) + CallOp([], "mul_r_gp", [rc.result, gp_c.result]) + CallOp([], "mul_r_gp", [rl.result, gp]) + CallOp([], "add_gp_g", [gp, gp_c.result]) + st = AllocOp(up_l1_t, [], []) + CallOp([], "zero_fill_sp_bf16", [st.result]) + CallOp( + [], "accum_sp_r_s", [sp_c.result, rc.result, st.result] + ) + CallOp([], "accum_sp_r_s", [sp_buf, rl.result, st.result]) + CallOp( + [], "vector_copy_32elems", [c0m, st.result, sp_c.result] + ) + subi_m = arith.SubIOp(ty, c1_h) + ChannelPut("cascade_gp", gp_c.result, indices=[tx, subi_m]) + ChannelPut("cascade_up", up_buf, indices=[tx, subi_m]) + ChannelPut("cascade_sp", sp_c.result, indices=[tx, subi_m]) + DeallocOp(gp_c) + DeallocOp(up_c) + DeallocOp(sp_c) + DeallocOp(up_s) + DeallocOp(rc) + DeallocOp(rl) + DeallocOp(st) + affine.AffineYieldOp([]) + + with InsertionPoint(if_mid.else_block): + # First stage (ty == 0): cascade in, merge, div, output + gp_c2 = AllocOp(gp_l1_t, [], []) + up_c2 = AllocOp(up_l1_t, [], []) + sp_c2 = AllocOp(up_l1_t, [], []) + ChannelGet("cascade_gp", gp_c2.result, indices=[tx, ty]) + ChannelGet("cascade_up", up_c2.result, indices=[tx, ty]) + ChannelGet("cascade_sp", sp_c2.result, indices=[tx, ty]) + up_s2 = AllocOp(up_l1_t, [], []) + c0f = ConstantOp(i32, 0) + CallOp( + [], "vector_copy_32elems", [c0f, up_buf, up_s2.result] + ) + CallOp([], "maximum_up_u_bf16", [up_c2.result, up_buf]) + rc2 = AllocOp(up_l1_t, [], []) + CallOp( + [], "exp_up_minus_u", [up_c2.result, up_buf, rc2.result] + ) + rl2 = AllocOp(up_l1_t, [], []) + CallOp( + [], "exp_up_minus_u", [up_s2.result, up_buf, rl2.result] + ) + CallOp([], "mul_r_gp", [rc2.result, gp_c2.result]) + CallOp([], "mul_r_gp", [rl2.result, gp]) + CallOp([], "add_gp_g", [gp, gp_c2.result]) + st2 = AllocOp(up_l1_t, [], []) + CallOp([], "zero_fill_sp_bf16", [st2.result]) + CallOp( + [], + "accum_sp_r_s", + [sp_c2.result, rc2.result, st2.result], + ) + CallOp([], "accum_sp_r_s", [sp_buf, rl2.result, st2.result]) + CallOp( + [], + "vector_copy_32elems", + [c0f, st2.result, sp_c2.result], + ) + CallOp([], "div_gp_sp", [sp_c2.result, gp_c2.result]) + c0_out = ConstantOp(index_type, 0) + ChannelPut( + "Gp2L2", + gp_c2.result, + indices=[tx, c0_out], + offsets=[0, 0, 0, 0], + sizes=[ + tile_size_q // M, + M, + dv_tile // M, + M, + ], + strides=[ + M * M, + M, + tile_size_q * M, + 1, + ], + ) + DeallocOp(gp_c2) + DeallocOp(up_c2) + DeallocOp(sp_c2) + DeallocOp(up_s2) + DeallocOp(rc2) + DeallocOp(rl2) + DeallocOp(st2) + affine.AffineYieldOp([]) + affine.AffineYieldOp([]) + + # === CAUSAL COUNTER INCREMENT === + if causal: + + def _emit_counter_increment(): + head_cur = load(counter_buf, [c2_ctr]) + c1_i32_inc = ConstantOp(i32, 1) + head_next = arith.AddIOp(head_cur, c1_i32_inc) + total_hg = ConstantOp(i32, num_head_groups) + wrapped = arith.CmpIOp( + arith.CmpIPredicate.sge, + head_next.result, + total_hg, + ) + if_wrap = scf.IfOp(wrapped) + with InsertionPoint(if_wrap.then_block): + q_cur = load(counter_buf, [c0_ctr]) + c_nq_i32 = ConstantOp(i32, NQ) + q_next = arith.AddIOp(q_cur, c_nq_i32) + store(q_next.result, counter_buf, [c0_ctr]) + store(ConstantOp(i32, 0), counter_buf, [c2_ctr]) + scf.YieldOp([]) + not_wrapped = arith.CmpIOp( + arith.CmpIPredicate.slt, + head_next.result, + total_hg, + ) + if_no_wrap = scf.IfOp(not_wrapped) + with InsertionPoint(if_no_wrap.then_block): + store(head_next.result, counter_buf, [c2_ctr]) + scf.YieldOp([]) + + if dv_chunks > 1: + dv_iter_cur = load(counter_buf, [c3_ctr]) + c_dv_last_i32 = ConstantOp(i32, dv_chunks - 1) + is_last_dv = arith.CmpIOp( + arith.CmpIPredicate.sge, + dv_iter_cur, + c_dv_last_i32, + ) + if_last_dv = scf.IfOp(is_last_dv) + with InsertionPoint(if_last_dv.then_block): + _emit_counter_increment() + store( + ConstantOp(i32, 0), + counter_buf, + [c3_ctr], + ) + scf.YieldOp([]) + not_last_dv = arith.CmpIOp( + arith.CmpIPredicate.slt, + dv_iter_cur, + c_dv_last_i32, + ) + if_not_last = scf.IfOp(not_last_dv) + with InsertionPoint(if_not_last.then_block): + c1_i32_dv = ConstantOp(i32, 1) + dv_next = arith.AddIOp(dv_iter_cur, c1_i32_dv) + store( + dv_next.result, + counter_buf, + [c3_ctr], + ) + scf.YieldOp([]) + else: + _emit_counter_increment() + + # Deallocs for segment-level buffers + for q_buf in q_saved_bufs: + DeallocOp(q_buf) + DeallocOp(k_saved_bufs[0]) + if enable_k_writeback: + DeallocOp(kwb_l1) + DeallocOp(v_l1) + DeallocOp(g_l1) + DeallocOp(gp_l1) + DeallocOp(up_l1) + DeallocOp(sp_l1) + for stage in range(NS): + DeallocOp(v_l2_bufs[stage]) + for stage in range(NS): + DeallocOp(qk_l2_bufs[stage]) + DeallocOp(gp_l2) + if causal: + DeallocOp(causal_ctr) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="attn_npu2.py", + description="Fused RoPE + flash attention + KV cache write-back", + ) + parser.add_argument( + "-p", + "--print-module-only", + action="store_true", + help="Print MLIR module and exit", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + help="Enable verbose output", + ) + parser.add_argument( + "--lk", + type=int, + default=512, + help="Total K/V sequence length (default: 512)", + ) + parser.add_argument( + "--lq", + type=int, + default=512, + help="Total Q sequence length (default: 512)", + ) + parser.add_argument( + "--lqp", + type=int, + default=256, + help="Q chunk size per launch iteration (default: 256)", + ) + parser.add_argument( + "--lkp", + type=int, + default=64, + help="K/V chunk size per tile (default: 64)", + ) + parser.add_argument( + "--dk", + type=int, + default=64, + help="Key dimension (default: 64). Must be divisible by lkp.", + ) + parser.add_argument( + "--dv", + type=int, + default=64, + help="Value dimension (default: 64). Must be divisible by lkp.", + ) + parser.add_argument( + "--num-cascade-stages", + type=int, + default=4, + help="Number of cascade pipeline stages (default: 4)", + ) + parser.add_argument( + "--num-heads", + type=int, + default=2, + help="Number of attention heads (default: 2)", + ) + parser.add_argument( + "--num-kv-heads", + type=int, + default=None, + help="Number of KV heads (default: num_heads for MHA, < num_heads for GQA)", + ) + parser.add_argument( + "--compile-mode", + type=str, + default="compile-and-run", + choices=["compile-only", "compile-and-run"], + help="Compilation mode (default: compile-and-run)", + ) + parser.add_argument( + "--output-format", + type=str, + default="elf", + choices=["xclbin", "elf"], + help="Output format (default: elf)", + ) + parser.add_argument( + "--causal", + action="store_true", + help="Enable causal masking (autoregressive attention)", + ) + parser.add_argument( + "--no-k-writeback", + action="store_true", + help="Disable K cache write-back (for debugging)", + ) + parser.add_argument( + "--no-rope", + action="store_true", + help="Disable RoPE application (for debugging data flow)", + ) + args = parser.parse_args() + + lk = args.lk + lkp = args.lkp + lq = args.lq + lqp = args.lqp + dk = args.dk + dv = args.dv + num_cascade_stages = args.num_cascade_stages + num_q_tiles = 4 + num_heads = args.num_heads + num_kv_heads = args.num_kv_heads if args.num_kv_heads is not None else num_heads + causal = args.causal + enable_k_writeback = not args.no_k_writeback + enable_rope = not args.no_rope + gqa_group_size = num_heads // num_kv_heads + + mlir_module = build_module( + lk=lk, + lkp=lkp, + lq=lq, + lqp=lqp, + dk=dk, + dv=dv, + num_q_tiles=num_q_tiles, + num_cascade_stages=num_cascade_stages, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + causal=causal, + enable_k_writeback=enable_k_writeback, + ) + + if args.print_module_only: + print(mlir_module) + exit(0) + + from air.backend.xrt_runner import XRTRunner + from air.backend.xrt import XRTBackend + from ml_dtypes import bfloat16 + + INPUT_DATATYPE = OUTPUT_DATATYPE = bfloat16 + rng = np.random.default_rng(42) + val_range = 4.0 + input_q = rng.uniform(0, val_range, (num_heads, lq, dk)).astype(INPUT_DATATYPE) + input_k = rng.uniform(0, val_range, (num_kv_heads, lk, dk)).astype(INPUT_DATATYPE) + input_v_orig = rng.uniform(0, val_range, (num_kv_heads, lk, dv)).astype( + INPUT_DATATYPE + ) + # Transpose V to [num_kv_heads * dv_chunks, lk, dv_tile] for contiguous access + dv_chunks_host = dv // lkp + input_v = ( + input_v_orig.reshape(num_kv_heads, lk, dv_chunks_host, lkp) + .transpose(0, 2, 1, 3) + .reshape(num_kv_heads * dv_chunks_host, lk, lkp) + .copy() + ) + + # ================================================================ + # Generate RoPE LUT: interleaved [cos, sin, cos, sin, ...] per row + # ================================================================ + THETA = 10000.0 + rope_seq_len = max(lq, lk) + rope_lut_f32 = np.zeros((rope_seq_len, dk), dtype=np.float32) + for r in range(rope_seq_len): + for i in range(dk // 2): + freq = 1.0 / (THETA ** (2.0 * i / dk)) + angle = r * freq + rope_lut_f32[r, 2 * i] = np.cos(angle) + rope_lut_f32[r, 2 * i + 1] = np.sin(angle) + rope_lut_input = rope_lut_f32.astype(INPUT_DATATYPE) + + # ================================================================ + # Apply RoPE to Q and K for reference computation + # ================================================================ + def apply_rope_ref(x, lut_slice): + """Apply RoPE rotation. x: [seq, dk], lut_slice: [seq, dk].""" + x_f = x.astype(np.float32) + lut_f = lut_slice.astype(np.float32) + x_even = x_f[:, 0::2] + x_odd = x_f[:, 1::2] + cos_v = lut_f[:, 0::2] + sin_v = lut_f[:, 1::2] + out = np.zeros_like(x_f) + out[:, 0::2] = x_even * cos_v - x_odd * sin_v + out[:, 1::2] = x_even * sin_v + x_odd * cos_v + return out.astype(x.dtype) + + # RoPE'd Q and K (host-side pre-rotation when enabled) + if enable_rope: + q_roped = np.zeros_like(input_q) + for h in range(num_heads): + q_roped[h] = apply_rope_ref(input_q[h], rope_lut_input[:lq, :dk]) + k_roped = np.zeros_like(input_k) + for kv_h in range(num_kv_heads): + k_roped[kv_h] = apply_rope_ref(input_k[kv_h], rope_lut_input[:lk, :dk]) + else: + q_roped = input_q.copy() + k_roped = input_k.copy() + + # NPU receives pre-RoPE'd Q and K when RoPE is enabled + npu_input_q = q_roped + npu_input_k = k_roped + + # Reference attention using RoPE'd Q and K + inv_sqrt_dk = 1.0 / sqrt(dk) + sdpa_output = np.zeros((num_heads, lq, dv), dtype=OUTPUT_DATATYPE) + for h in range(num_heads): + kv_h = h // gqa_group_size + Qf = q_roped[h].astype(np.float32) + Kf = k_roped[kv_h].astype(np.float32) + Vf = input_v_orig[kv_h].astype(np.float32) + scores = Qf @ Kf.T * inv_sqrt_dk + if causal: + mask = np.triu(np.ones(scores.shape, dtype=bool), k=1) + scores = np.where(mask, -1e9, scores) + mx = np.max(scores, axis=-1, keepdims=True) + P = np.exp(scores - mx) + P = P / np.sum(P, axis=-1, keepdims=True) + sdpa_output[h] = (P @ Vf).astype(OUTPUT_DATATYPE) + + # Transpose expected output to match transposed L3 layout + sdpa_output_transposed = ( + sdpa_output.reshape(num_heads, lq, dv_chunks_host, lkp) + .transpose(0, 2, 1, 3) + .reshape(num_heads * dv_chunks_host, lq, lkp) + .copy() + ) + + tiling = [1, 1, 1] if dv_chunks_host > 1 else [1, 1] + backend_opts = dict( + omit_while_true_loop=False, + omit_pingpong="all", + verbose=args.verbose, + runtime_loop_tiling_sizes=tiling, + output_format=args.output_format, + instance_name="attention_bf16", + target_device="npu2", + ) + + # K cache expected output: RoPE'd K + expected_k_cache = k_roped.copy() + + if args.compile_mode == "compile-and-run": + import filelock, tempfile + + backend = XRTBackend(**backend_opts) + # 3 output buffers: attention, K cache, V cache (V cache unused by NPU) + v_cache_placeholder = np.zeros_like(input_v) + expected_outputs = [ + sdpa_output_transposed, + expected_k_cache, + v_cache_placeholder, + ] + output_placeholders = [np.zeros(o.shape, o.dtype) for o in expected_outputs] + # NPU receives pre-RoPE'd Q and K (RoPE applied on host when enabled) + input_list = [npu_input_q, npu_input_k, input_v] + num_inputs = len(input_list) + expanded_inputs = input_list + output_placeholders + + compiled_module = backend.compile(mlir_module) + with filelock.FileLock(os.path.join(tempfile.gettempdir(), "npu.lock")): + module_function = backend.load(compiled_module) + actual_outputs = module_function(*expanded_inputs) + backend.unload() + + # Remove input slots + actual_outputs = list(actual_outputs[num_inputs:]) + + failed = False + + # --- Output 0: Attention (SDPA) output --- + attn_actual = ( + actual_outputs[0].reshape(sdpa_output_transposed.shape).astype(np.float32) + ) + attn_expected = sdpa_output_transposed.astype(np.float32) + attn_corr = float( + np.corrcoef(attn_actual.flatten(), attn_expected.flatten())[0, 1] + ) + attn_close = np.isclose(attn_actual, attn_expected, rtol=0.04, atol=0.15) + attn_mismatches = int(np.sum(~attn_close)) + attn_total = attn_actual.size + attn_pct = attn_mismatches / attn_total * 100 + print( + f"Output 0 (attention): correlation={attn_corr:.4f}, " + f"mismatches={attn_mismatches}/{attn_total} ({attn_pct:.2f}%)" + ) + if attn_corr < 0.94: + print(f"FAIL: Output 0 correlation {attn_corr:.4f} below 0.94") + failed = True + else: + print( + f"PASS: Output 0 correlation {attn_corr:.4f} >= 0.94 " + "(BFP16 emulation tolerance)" + ) + + # --- Output 1: K cache (should be RoPE'd K) --- + if enable_k_writeback: + k_actual = actual_outputs[1].reshape(expected_k_cache.shape) + k_mismatches = int(np.sum(k_actual != expected_k_cache)) + k_total = k_actual.size + print(f"Output 1 (K cache): mismatches={k_mismatches}/{k_total}") + if k_mismatches > 0: + print( + f"FAIL: K cache has {k_mismatches} mismatches (expected RoPE'd K)" + ) + # Debug: show match pattern per head and row block + for h in range(expected_k_cache.shape[0]): + for rb in range(expected_k_cache.shape[1] // lkp): + chunk = k_actual[h, rb * lkp : (rb + 1) * lkp, :].astype( + np.float32 + ) + exp_chunk = expected_k_cache[ + h, rb * lkp : (rb + 1) * lkp, : + ].astype(np.float32) + m = int( + np.sum( + k_actual[h, rb * lkp : (rb + 1) * lkp, :] + == expected_k_cache[h, rb * lkp : (rb + 1) * lkp, :] + ) + ) + t = chunk.size + corr = ( + float( + np.corrcoef(chunk.flatten(), exp_chunk.flatten())[0, 1] + ) + if np.std(chunk) > 0 and np.std(exp_chunk) > 0 + else 0.0 + ) + # Check if wrong chunk matches a different chunk + if m < t: + best_match = -1 + best_corr = -1.0 + for rb2 in range(expected_k_cache.shape[1] // lkp): + exp2 = expected_k_cache[ + h, rb2 * lkp : (rb2 + 1) * lkp, : + ].astype(np.float32) + c2 = ( + float( + np.corrcoef(chunk.flatten(), exp2.flatten())[ + 0, 1 + ] + ) + if np.std(exp2) > 0 + else 0.0 + ) + if c2 > best_corr: + best_corr = c2 + best_match = rb2 + print( + f" head={h} chunk={rb}: {m}/{t} match, corr={corr:.4f}, best_match=chunk{best_match} (corr={best_corr:.4f})" + ) + else: + print(f" head={h} chunk={rb}: {m}/{t} match (EXACT)") + diff_idx = np.argwhere(k_actual != expected_k_cache) + for idx in diff_idx[:10]: + idx_t = tuple(idx) + print( + f" {idx_t}: expected={expected_k_cache[idx_t]}, " + f"actual={k_actual[idx_t]}" + ) + failed = True + else: + print("PASS: K cache matches RoPE'd K") + else: + print("Output 1 (K cache): SKIPPED (K write-back disabled)") + + # Output 2: V cache — not written by NPU (skipped) + print("Output 2 (V cache): SKIPPED (not written by NPU, use host copy)") + + if failed: + print("OVERALL: FAILED") + exit(-1) + else: + print("OVERALL: PASSED") + exit(0) + elif args.compile_mode == "compile-only": + backend = XRTBackend(**backend_opts) + module_function = backend.compile(mlir_module) + print("Compilation complete.") diff --git a/programming_examples/flash_attention/kv_cache_prefill/run_npu2_makefile_peano_elf.lit b/programming_examples/flash_attention/kv_cache_prefill/run_npu2_makefile_peano_elf.lit new file mode 100644 index 000000000..6f4a2d210 --- /dev/null +++ b/programming_examples/flash_attention/kv_cache_prefill/run_npu2_makefile_peano_elf.lit @@ -0,0 +1,10 @@ +// (c) Copyright 2025 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT +// +// REQUIRES: ryzen_ai_npu2, peano +// +// RUN: mkdir -p test_npu2_peano_elf +// RUN: cd test_npu2_peano_elf +// RUN: make -f %S/Makefile clean +// RUN: make -f %S/Makefile run PEANO_INSTALL_DIR=%PEANO_INSTALL_DIR | FileCheck %s +// CHECK: PASS! diff --git a/programming_examples/flash_attention/kv_cache_prefill/run_test.sh b/programming_examples/flash_attention/kv_cache_prefill/run_test.sh new file mode 100755 index 000000000..d11041d07 --- /dev/null +++ b/programming_examples/flash_attention/kv_cache_prefill/run_test.sh @@ -0,0 +1,20 @@ +#!/bin/bash +set -e + +# Source XRT first +source /opt/xilinx/xrt/setup.sh 2>/dev/null || true + +# Activate sandbox venv (after XRT to preserve PATH) +source /home/strixminipc/new_session/mlir-air/sandbox/bin/activate + +# Set paths - build/bin first so we get the C++ aircc binary +export PATH=/home/strixminipc/new_session/mlir-air/build/bin:/home/strixminipc/new_session/mlir-air/mlir-aie/install/bin:$PATH +export PYTHONPATH=/home/strixminipc/new_session/mlir-air/build/python:/home/strixminipc/new_session/mlir-air/mlir-aie/install/python:$PYTHONPATH +export LD_LIBRARY_PATH=/home/strixminipc/new_session/mlir-air/build/lib:/home/strixminipc/new_session/mlir-air/mlir-aie/install/lib:$LD_LIBRARY_PATH + +# Peano compiler (llvm-aie) installed as pip package in sandbox +export PEANO_INSTALL_DIR=/home/strixminipc/new_session/mlir-air/sandbox/lib/python3.13/site-packages/llvm-aie + +cd /home/strixminipc/new_session/mlir-air/programming_examples/flash_attention/kv_cache_prefill/build_peano + +exec python3 ../attn_npu2.py "$@" diff --git a/programming_examples/flash_attention/kv_cache_prefill/test_elf_npu2.cpp b/programming_examples/flash_attention/kv_cache_prefill/test_elf_npu2.cpp new file mode 100644 index 000000000..41b9e0ef0 --- /dev/null +++ b/programming_examples/flash_attention/kv_cache_prefill/test_elf_npu2.cpp @@ -0,0 +1,252 @@ +//===- test_elf_npu2.cpp --------------------------------------*- C++ -*-===// +// +// SPDX-License-Identifier: MIT +// +// Copyright (C) 2026, Advanced Micro Devices, Inc. +// +//===----------------------------------------------------------------------===// + +#include "cxxopts.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "test_utils.h" + +#include "xrt/xrt_bo.h" +#include "xrt/xrt_device.h" +#include "xrt/xrt_kernel.h" + +// Experimental headers for elf format support +#include +#include +#include + +using DATATYPE = std::bfloat16_t; + +static constexpr double THETA = 10000.0; + +static inline std::bfloat16_t random_bfloat16_t() { + // Random numbers should NOT be uniformly between 0 and 1, because that + // would make the matrix product AB always close to 1. + return std::bfloat16_t(4.0 * (float)rand() / (float)(RAND_MAX)); +} + +int main(int argc, const char *argv[]) { + + // Program arguments parsing + cxxopts::Options options("Allowed options"); + + options.add_options()("help,h", "produce help message")( + "elf,e", "the input elf path", cxxopts::value())( + "kernel,k", "the kernel name (format: :)", + cxxopts::value())("verbosity,v", + "the verbosity of the output", + cxxopts::value()->default_value("0"))( + "lq", "Query sequence length", + cxxopts::value()->default_value("512"))( + "lk", "Key/Value sequence length", + cxxopts::value()->default_value("12288"))( + "dk", "Key dimension", cxxopts::value()->default_value("64"))( + "dv", "Value dimension", cxxopts::value()->default_value("64"))( + "num-heads", "Number of attention heads", + cxxopts::value()->default_value("12"))( + "warmup,w", "Number of warmup iterations", + cxxopts::value()->default_value("10"))( + "iterations,n", "Number of iterations", + cxxopts::value()->default_value("20"))( + "trace-size,t", "Trace buffer size in bytes (0 to disable tracing)", + cxxopts::value()->default_value("0")); + + auto vm = options.parse(argc, argv); + + if (vm.count("help")) { + std::cout << options.help() << std::endl; + return 1; + } + + // Check required options + if (!vm.count("elf") || !vm.count("kernel")) { + std::cerr << "Error: Required options missing\n\n"; + std::cerr << "Usage:\n" << options.help() << std::endl; + return 1; + } + + // Get trace size from command line + int trace_size = vm["trace-size"].as(); + + // Get dimensions from command line + int lq = vm["lq"].as(); + int lk = vm["lk"].as(); + int dk = vm["dk"].as(); + int dv = vm["dv"].as(); + int num_heads = vm["num-heads"].as(); + + size_t Q_VOLUME = (size_t)num_heads * lq * dk; + size_t K_VOLUME = (size_t)num_heads * lk * dk; + size_t V_VOLUME = (size_t)num_heads * lk * dv; + size_t OUTPUT_VOLUME = (size_t)num_heads * lq * dv; + + size_t Q_SIZE = Q_VOLUME * sizeof(DATATYPE); + size_t K_SIZE = K_VOLUME * sizeof(DATATYPE); + size_t V_SIZE = V_VOLUME * sizeof(DATATYPE); + size_t OUTPUT_SIZE = OUTPUT_VOLUME * sizeof(DATATYPE); + + int verbosity = vm["verbosity"].as(); + + // Start the XRT test code + // Get a device handle + unsigned int device_index = 0; + auto device = xrt::device(device_index); + + // Load the elf and create context + std::string elfPath = vm["elf"].as(); + if (verbosity >= 1) + std::cout << "Loading elf: " << elfPath << "\n"; + + xrt::elf ctx_elf{elfPath}; + xrt::hw_context context = xrt::hw_context(device, ctx_elf); + + // The name format here is : from the config.json + std::string kernelName = vm["kernel"].as(); + if (verbosity >= 1) + std::cout << "Kernel name: " << kernelName << "\n"; + + auto kernel = xrt::ext::kernel(context, kernelName); + + // Create buffer objects using xrt::ext::bo (declared as xrt::bo type) + // Kernel signature: attention_bf16(Q, K, V, Output, K_cache, V_cache) + xrt::bo bo_q = xrt::ext::bo{device, Q_SIZE}; + xrt::bo bo_k = xrt::ext::bo{device, K_SIZE}; + xrt::bo bo_v = xrt::ext::bo{device, V_SIZE}; + xrt::bo bo_out = + xrt::ext::bo{device, OUTPUT_SIZE + static_cast(trace_size)}; + xrt::bo bo_k_cache = xrt::ext::bo{device, K_SIZE}; + xrt::bo bo_v_cache = xrt::ext::bo{device, V_SIZE}; + + unsigned n_iterations = vm["iterations"].as(); + unsigned n_warmup_iterations = vm["warmup"].as(); + unsigned num_iter = n_iterations + n_warmup_iterations; + float npu_time_total = 0; + float npu_time_min = std::numeric_limits::max(); + float npu_time_max = 0; + + // FLOPs for attention: Q@K^T (lq*lk*dk*2) + softmax(~5*lq*lk) + S@V + // (lq*dv*lk*2) per head, multiplied by num_heads + float macs = + (float)num_heads * ((float)lq * lk * dk * 2 + (float)lk * lq * dv * 2); + + std::cout << "Flash Attention Benchmark (ELF format)" << std::endl; + std::cout << " num_heads=" << num_heads << ", lq=" << lq << ", lk=" << lk + << ", dk=" << dk << ", dv=" << dv << std::endl; + std::cout << " Q: [" << num_heads << "x" << lq << "x" << dk << "] (" + << Q_SIZE << " bytes)" << std::endl; + std::cout << " K: [" << num_heads << "x" << lk << "x" << dk << "] (" + << K_SIZE << " bytes)" << std::endl; + std::cout << " V: [" << num_heads << "x" << lk << "x" << dv << "] (" + << V_SIZE << " bytes)" << std::endl; + std::cout << " Output: [" << num_heads << "x" << lq << "x" << dv << "] (" + << OUTPUT_SIZE << " bytes)" << std::endl; + std::cout << " RoPE: computed in-kernel (sincos)" << std::endl; + + if (verbosity >= 1) + std::cout << "Writing data into buffer objects.\n"; + + DATATYPE *bufQ = bo_q.map(); + std::vector QVec; + for (size_t i = 0; i < Q_VOLUME; i++) + QVec.push_back(random_bfloat16_t()); + memcpy(bufQ, QVec.data(), (QVec.size() * sizeof(DATATYPE))); + + DATATYPE *bufK = bo_k.map(); + std::vector KVec; + for (size_t i = 0; i < K_VOLUME; i++) + KVec.push_back(random_bfloat16_t()); + memcpy(bufK, KVec.data(), (KVec.size() * sizeof(DATATYPE))); + + DATATYPE *bufV = bo_v.map(); + std::vector VVec; + for (size_t i = 0; i < V_VOLUME; i++) + VVec.push_back(random_bfloat16_t()); + memcpy(bufV, VVec.data(), (VVec.size() * sizeof(DATATYPE))); + + DATATYPE *bufOut = bo_out.map(); + memset(bufOut, 0, OUTPUT_SIZE + trace_size); + + DATATYPE *bufKCache = bo_k_cache.map(); + memset(bufKCache, 0, K_SIZE); + DATATYPE *bufVCache = bo_v_cache.map(); + memset(bufVCache, 0, V_SIZE); + + bo_q.sync(XCL_BO_SYNC_BO_TO_DEVICE); + bo_k.sync(XCL_BO_SYNC_BO_TO_DEVICE); + bo_v.sync(XCL_BO_SYNC_BO_TO_DEVICE); + bo_out.sync(XCL_BO_SYNC_BO_TO_DEVICE); + bo_k_cache.sync(XCL_BO_SYNC_BO_TO_DEVICE); + bo_v_cache.sync(XCL_BO_SYNC_BO_TO_DEVICE); + + for (unsigned iter = 0; iter < num_iter; iter++) { + if (verbosity >= 1) + std::cout << "Running Kernel (iteration " << iter << ").\n"; + + auto run = xrt::run(kernel); + run.set_arg(0, bo_q); + run.set_arg(1, bo_k); + run.set_arg(2, bo_v); + run.set_arg(3, bo_out); + run.set_arg(4, bo_k_cache); + run.set_arg(5, bo_v_cache); + + auto start = std::chrono::high_resolution_clock::now(); + run.start(); + run.wait2(); + auto stop = std::chrono::high_resolution_clock::now(); + + bo_out.sync(XCL_BO_SYNC_BO_FROM_DEVICE); + bo_k_cache.sync(XCL_BO_SYNC_BO_FROM_DEVICE); + bo_v_cache.sync(XCL_BO_SYNC_BO_FROM_DEVICE); + + if (iter < n_warmup_iterations) { + /* Warmup iterations do not count towards average runtime. */ + continue; + } + + float npu_time = + std::chrono::duration_cast(stop - start) + .count(); + + npu_time_total += npu_time; + npu_time_min = (npu_time < npu_time_min) ? npu_time : npu_time_min; + npu_time_max = (npu_time > npu_time_max) ? npu_time : npu_time_max; + } + if (verbosity >= 1) + std::cout << "Done Running Kernel.\n"; + + if (trace_size > 0) { + test_utils::write_out_trace(((char *)bufOut) + OUTPUT_SIZE, trace_size, + "trace.txt"); + } + + std::cout << std::endl + << "Avg NPU attention time: " << npu_time_total / n_iterations + << "us." << std::endl; + std::cout << "Avg NPU gflops: " + << macs / (1000 * npu_time_total / n_iterations) << std::endl; + + std::cout << std::endl + << "Min NPU attention time: " << npu_time_min << "us." << std::endl; + std::cout << "Max NPU gflops: " << macs / (1000 * npu_time_min) << std::endl; + + std::cout << std::endl + << "Max NPU attention time: " << npu_time_max << "us." << std::endl; + std::cout << "Min NPU gflops: " << macs / (1000 * npu_time_max) << std::endl; + + return 0; +} From d15320860c423e633b23499d8c5c9a7e1a30efd5 Mon Sep 17 00:00:00 2001 From: erweiw Date: Wed, 8 Apr 2026 11:19:16 -0700 Subject: [PATCH 2/6] Add V cache write-back with interleaved KV cache layout Extend the KV cache prefill design to write both K and V caches to DDR during flash attention computation. Uses a single CacheWB channel with an interleaved KV cache layout [K_c0, V_c0, K_c1, V_c1, ...] where both K and V data are staged through kwb_buf before DMA transfer. Key design choices: - Single CacheWB channel avoids shim S2MM channel exhaustion (no packet switching needed) - Shared kwb_buf staging buffer prevents DMA race between CacheWB read and V2L1 write on the v buffer - scf.for loop in launch body enables compiler BD folding, preventing BD exhaustion at large sequence lengths (tested up to 12h x 4096) Compiler changes (AIRToAIEPass.cpp): - Fix packet BD attribute lookup for L1-to-L3 dma_packet channels (getExistingPacketFlowOpFromDevice searches both flow maps) - Place outbound MM2S lock acquire before channel put and release after channel put, enabling interleaved lock pattern for multiple puts sharing the same staging buffer Performance: 12 heads x 4096 seq_len achieves 2460 peak GFLOPS with zero overhead vs K-only writeback. Co-Authored-By: Claude Opus 4.6 (1M context) --- mlir/lib/Conversion/AIRToAIEPass.cpp | 134 +++--- .../kv_cache_prefill/attn_npu2.py | 409 ++++++++++++------ .../kv_cache_prefill/test_elf_npu2.cpp | 22 +- 3 files changed, 341 insertions(+), 224 deletions(-) diff --git a/mlir/lib/Conversion/AIRToAIEPass.cpp b/mlir/lib/Conversion/AIRToAIEPass.cpp index 24e5f6c53..8349d9cea 100644 --- a/mlir/lib/Conversion/AIRToAIEPass.cpp +++ b/mlir/lib/Conversion/AIRToAIEPass.cpp @@ -2331,11 +2331,6 @@ struct SpecializeChannelBundlePattern for (auto put : channelPuts) { auto indices_uint = air::convertVecOfConstIndexToVecOfUInt(put.getIndices()); - if (indices_uint.empty() && !put.getIndices().empty() && iter == 0) - put->emitWarning( - "channel bundle indices cannot be resolved to compile-time " - "constants; this channel put will be replaced with " - "air.wait_all, which may cause data loss"); if (areIdenticalVectors(indices_uint, position)) { // Found channel put for this channel rewriter.setInsertionPoint(put); @@ -2354,11 +2349,6 @@ struct SpecializeChannelBundlePattern for (auto get : channelGets) { auto indices_uint = air::convertVecOfConstIndexToVecOfUInt(get.getIndices()); - if (indices_uint.empty() && !get.getIndices().empty() && iter == 0) - get->emitWarning( - "channel bundle indices cannot be resolved to compile-time " - "constants; this channel get will be replaced with " - "air.wait_all, which may cause data loss"); if (areIdenticalVectors(indices_uint, position)) { // Found channel get for this channel rewriter.setInsertionPoint(get); @@ -3023,47 +3013,57 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase { return AIE::PacketFlowOp(); // Only air.channel_interface ops support // packet-flow routing. - // Convert a flow map from Operation pointers to channel symbol names. + // Determine if this is a shim flow by checking if EITHER source OR + // destination tile is a shim tile. This must be consistent with + // placeDMAChannelsAndRouteFlows which uses the same criteria. + auto sourceTileOp = source.getDefiningOp(); + bool sourceIsShim = sourceTileOp && sourceTileOp.isShimNOCorPLTile(); + + // Check if the destination involves a shim tile by examining the memcpy's + // memory spaces (L3 memory space indicates shim tile involvement) + bool destIsShim = false; + if (auto srcMemref = memcpyOp.getSrcMemref()) { + auto memrefTy = dyn_cast_if_present(srcMemref.getType()); + if (memrefTy && air::isL3(memrefTy)) + destIsShim = true; + } + if (auto dstMemref = memcpyOp.getDstMemref()) { + auto memrefTy = dyn_cast_if_present(dstMemref.getType()); + if (memrefTy && air::isL3(memrefTy)) + destIsShim = true; + } + + bool isShimFlow = sourceIsShim || destIsShim; + + // Select the appropriate flow map based on whether this involves shim tiles + const SetVector &flowMap = + isShimFlow ? shimFlowOpToFlowIdMap : intraDeviceFlowOpToFlowIdMap; + + // Convert flowMap from Operation pointers to channel symbol names. // This is necessary because air.channel declarations are duplicated // under aie.device op and its parent module op, requiring symbol-based // matching. - auto buildFlowIdMap = - [](const SetVector &fmap) -> std::vector { - std::vector result; - for (auto op : fmap) { - auto flowChanOp = dyn_cast_if_present(op); - if (!flowChanOp) { - result.push_back(""); - continue; - } - result.push_back(flowChanOp.getSymName().str()); + std::vector flowOpStringsToFlowIdMap; + for (auto op : flowMap) { + auto flowChanOp = dyn_cast_if_present(op); + if (!flowChanOp) { + flowOpStringsToFlowIdMap.push_back(""); + continue; } - return result; - }; + flowOpStringsToFlowIdMap.push_back(flowChanOp.getSymName().str()); + } - // Search both flow maps by channel name. Channel names are unique symbols, - // so each channel appears in exactly one map. We search the shim (device- - // host) map first, then the intra-device map. - // - // Note: we cannot reliably determine which map to search from the memcpy - // op alone, because ChannelPutOp::getDstMemref() and ChannelGetOp:: - // getSrcMemref() return nullptr by design (the other end of a channel op - // is implicit via the channel symbol). Searching both maps directly is - // simpler and always correct. - std::string chanName = chanIfOp.getChanName().str(); - - for (const auto &flowMap : {std::cref(shimFlowOpToFlowIdMap), - std::cref(intraDeviceFlowOpToFlowIdMap)}) { - auto flowStrings = buildFlowIdMap(flowMap.get()); - auto it = llvm::find(flowStrings, chanName); - if (it != flowStrings.end()) { - int flowID = std::distance(flowStrings.begin(), it); - return findPacketFlowOp(source, sourceBundle, sourceChannel, - /*checkFlowID=*/true, flowID); - } + // Find the flowID by matching the channel name + auto it = + llvm::find(flowOpStringsToFlowIdMap, chanIfOp.getChanName().str()); + if (it == flowOpStringsToFlowIdMap.end()) { + return AIE::PacketFlowOp(); } + int flowID = std::distance(flowOpStringsToFlowIdMap.begin(), it); - return AIE::PacketFlowOp(); + // Search for the packet flow with matching source and flowID + return findPacketFlowOp(source, sourceBundle, sourceChannel, + /*checkFlowID=*/true, flowID); } /// Query an existing packet flow operation from the runtime function. @@ -4053,8 +4053,8 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase { // Get index to metadataArray based on channel indices. auto iter = air::getIndexToMetadataArrayFromChannelIndices(ci); if (!iter) { - ci->emitOpError( - "channel indices failed to convert to metadataArray index."); + ci->emitOpError("channel indices failed to convert to convert to " + "metadataArray index."); return failure(); } // Get metadata from metadataArray. @@ -4678,25 +4678,6 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase { alloc = memcpyOpIf.getSrcMemref(); } - // Detect if multiple outbound puts in this DMA allocation share the same - // source buffer. When true, use per-put interleaved lock placement to - // prevent the second put from overwriting the buffer before the DMA - // finishes reading the first put's data. - bool sharedStagingBuffer = false; - if (!tileInbound.value() && isa(alloc.getDefiningOp()) && - dma_alloc.value().memcpyOps.size() > 1) { - int sameBufCount = 0; - for (auto *op : dma_alloc.value().memcpyOps) { - if (auto other = dyn_cast_if_present(op)) { - auto otherInbound = isTileInbound(other, air::MemorySpace::L1); - if (succeeded(otherInbound) && !otherInbound.value() && - other.getSrcMemref() == alloc) - sameBufCount++; - } - } - sharedStagingBuffer = sameBufCount > 1; - } - if (auto bco = dyn_cast_if_present( alloc.getDefiningOp())) builder.setInsertionPoint(bco.getOperand().getDefiningOp()); @@ -4704,19 +4685,7 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase { builder.setInsertionPoint(alloc.getDefiningOp()); else if (!tileInbound.value() && isa(alloc.getDefiningOp())) { - if (sharedStagingBuffer) { - // Interleaved mode: acquire immediately before this specific put, so - // the core waits for the DMA to finish reading the previous put's - // data before overwriting the buffer. - builder.setInsertionPoint(memcpyOpIf); - } else { - auto br = dyn_cast_if_present( - memcpyOpIf->getBlock()->getTerminator()); - if (br) - builder.setInsertionPointToStart(br.getDest()); - else - builder.setInsertionPointToStart(memcpyOpIf->getBlock()); - } + builder.setInsertionPoint(memcpyOpIf); } else builder.setInsertionPoint(memcpyOpIf); @@ -4727,11 +4696,10 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase { lockAqValue); // Try to find the end of lifetime for the data copied by memcpyOpIf, and - // put the unlock. - if (sharedStagingBuffer) { - // Interleaved mode: release rlock immediately after the put so the DMA - // can read the buffer before the next put overwrites it. The next put's - // acquire(wlock) will block until the DMA completes reading. + // put the unlock. For outbound puts from AIE::BufferOp, release + // immediately after the put to enable interleaved operation when multiple + // puts share the same staging buffer. + if (!tileInbound.value() && isa(alloc.getDefiningOp())) { builder.setInsertionPointAfter(memcpyOpIf); AIE::UseLockOp::create(builder, memcpyOpIf->getLoc(), relLockOp, AIE::LockAction::Release, lockRelValue); diff --git a/programming_examples/flash_attention/kv_cache_prefill/attn_npu2.py b/programming_examples/flash_attention/kv_cache_prefill/attn_npu2.py index 14c9d1fb1..b51db1895 100644 --- a/programming_examples/flash_attention/kv_cache_prefill/attn_npu2.py +++ b/programming_examples/flash_attention/kv_cache_prefill/attn_npu2.py @@ -4,9 +4,9 @@ Single-launch design that fuses flash attention and KV cache write-back into one AIE program. When RoPE is enabled, Q and K are pre-RoPE'd by -the host before being sent to the NPU — the NPU performs attention on -already-rotated data. The K data is written back to L3 K cache via DMA -from tx=0 tiles. +the host before being sent to the NPU -- the NPU performs attention on +already-rotated data. Both K and V data are written back to an +interleaved KV cache buffer in L3 during attention computation. TODO: Replace host-side RoPE with on-chip rope_sincos kernel that computes sin/cos directly on AIE2P without needing a LUT. @@ -14,18 +14,64 @@ DMA channel strategy (2 S2MM + 2 MM2S per compute tile): S2MM 0: QK channel (Q and K via L2 relay) S2MM 1: V (per-stage via memtile) - MM2S 0: Cascade or output (ty=0) - MM2S 1: K cache write-back (tx=0 only) + MM2S 0: CacheWB (K+V write-back, tx=0 only) or cascade/output + MM2S 1: Gp2L2 output gather (ty=0 only) Channel layout: QKIn_s/QK2L1_s: per-stage memtile relay with horizontal broadcast VIn_s/V2L1_s: per-stage memtile relay with horizontal broadcast cascade_gp/cascade_up/cascade_sp: 2D cascade channels (per-segment) Gp2L2/GpOut: output from ty=0 tiles - KWB: K cache write-back (tx=0 tiles send K directly to L3 via tile DMA) - -Note: V cache is NOT written by the NPU. The host should copy input V -directly to the V cache buffer if needed. + CacheWB: K+V cache write-back (tx=0 tiles, single channel for both K + and V via shared kwb_buf staging buffer) + +KV Cache Layout +=============== +The KV cache is stored as a single flat buffer with K and V chunks +interleaved per-chunk. This layout is chosen because: + 1. It allows a single DMA channel (CacheWB) to write both K and V, + avoiding shim S2MM channel exhaustion. + 2. The DMA uses a single BD with a staging buffer (kwb_buf), sending K + then V for each chunk iteration. The interleaved layout means + consecutive DMA transfers write to consecutive L3 offsets. + 3. The scf.for loop in the launch body enables the compiler to fold + multiple chunk transfers into a single higher-dimensional shim BD, + preventing BD exhaustion at large sequence lengths. + +Layout (per KV head): + [K_chunk0, V_chunk0, K_chunk1, V_chunk1, ..., K_chunkN, V_chunkN] + +Where each chunk is [lkp, dk_tile] = [64, 64] bf16 elements (8 KB). + +Full buffer shape (logical): + [num_kv_heads, num_chunks, 2, lkp, dk_tile] + | | | | | + | | | | head dimension tile (= lkp) + | | | chunk rows (= lkp) + | | 0=K (RoPE'd), 1=V (raw) + | lk / lkp chunks per head + KV head index + +Physical flat size: num_kv_heads * num_chunks * 2 * lkp * dk_tile bf16. + +K data: RoPE-rotated key data, un-tiled from [M,M] blocked L1 format to + row-major [lkp, dk_tile] during DMA write-back. +V data: Raw value data (no RoPE), un-tiled from [M,M] blocked L1 format + to row-major [lkp, dv_tile] during DMA write-back. + +Limitation: the interleaved layout currently requires dk == dv == lkp +(i.e., dk_chunks == dv_chunks == 1). This is satisfied by the target +model (head_dim=64, lkp=64). Supporting dk > lkp would require +extending the interleaving pattern to handle multiple dk_tile transfers +per K chunk. + +For decode (future consumers): to read K or V separately from this +interleaved buffer, use a DMA stride of 2*lkp*dk_tile between +consecutive K (or V) chunks. For example, to read all K chunks for +head h: + base_offset = h * num_chunks * 2 * lkp * dk_tile + K_chunk[i] at offset: base_offset + i * 2 * lkp * dk_tile + V_chunk[i] at offset: base_offset + i * 2 * lkp * dk_tile + lkp*dk_tile """ import argparse @@ -60,6 +106,7 @@ def build_module( num_kv_heads=None, causal=False, enable_k_writeback=True, + enable_v_writeback=True, ): """Build flash attention + KV cache module (RoPE applied on host). @@ -93,6 +140,13 @@ def build_module( dv_tile = lkp assert dv % dv_tile == 0, f"dv ({dv}) must be divisible by dv_tile/lkp ({dv_tile})" dv_chunks = dv // dv_tile + enable_cache_writeback = enable_k_writeback or enable_v_writeback + if enable_cache_writeback: + assert dk == lkp and dv == lkp, ( + f"Interleaved KV cache write-back requires dk == dv == lkp, " + f"got dk={dk}, dv={dv}, lkp={lkp}. " + f"Use --no-k-writeback --no-v-writeback to disable." + ) if causal: assert lq == lk, f"Causal masking requires lq == lk, got lq={lq}, lk={lk}" assert lqp // num_q_tiles == lkp, ( @@ -160,6 +214,14 @@ def build_module( gp_l3_t = MemRefType.get([num_heads * dv_chunks, lq, dv_tile], bf16) # KV cache L3 types + if enable_k_writeback or enable_v_writeback: + # Interleaved KV cache: [num_kv_heads, num_chunks, 2, lkp, dk_tile] + # K at index 0, V at index 1 within each chunk pair + kv_cache_l3_t = MemRefType.get( + [num_kv_heads * num_chunks * 2 * lkp * dk_tile], bf16 + ) + kv_chunk_stride = 2 * lkp * dk_tile # stride between K and V in a pair + # Legacy separate caches (when writeback disabled, used as placeholders) k_cache_l3_t = MemRefType.get([num_kv_heads, lk, dk], bf16) v_cache_l3_t = MemRefType.get([num_kv_heads * dv_chunks, lk, dv_tile], bf16) @@ -245,18 +307,27 @@ def external_func(name, inputs, outputs=None, link_with=None, visibility="privat Channel("Gp2L2", size=[NQ, 1]) Channel("GpOut", size=[num_heads_per_unroll]) - # K cache write-back: tile-level L1→L3 (tx=0 tiles write K to shim) - if enable_k_writeback: - Channel("KWB", size=[num_heads_per_unroll, NS, 1]) + # KV cache write-back: single channel, tx=0 tiles send K then V per chunk + # into interleaved KV cache buffer. + if enable_cache_writeback: + Channel("CacheWB", size=[num_heads_per_unroll, NS, 1]) # ---------------------------------------------------------------- # Main function: fused RoPE + attention + KV cache # ---------------------------------------------------------------- - func_args = [q_l3_t, k_l3_t, v_l3_t, gp_l3_t, k_cache_l3_t, v_cache_l3_t] + func_args = [q_l3_t, k_l3_t, v_l3_t, gp_l3_t] + if enable_cache_writeback: + func_args.append(kv_cache_l3_t) + else: + func_args.extend([k_cache_l3_t, v_cache_l3_t]) @FuncOp.from_py_func(*func_args) def attention_bf16(*func_params): - q_in, k_in, v_in, gp_out, k_cache, v_cache = func_params + if enable_cache_writeback: + q_in, k_in, v_in, gp_out, kv_cache = func_params + else: + q_in, k_in, v_in, gp_out, k_cache, v_cache = func_params + kv_cache = None c1 = ConstantOp(index_type, 1) c_lq_iters = ConstantOp(index_type, num_lq_iters) c_num_head_groups = ConstantOp(index_type, num_head_groups) @@ -267,18 +338,28 @@ def attention_bf16(*func_params): else: launch_sizes = [c_lq_iters, c_num_head_groups] - launch_operands = [q_in, k_in, v_in, gp_out, k_cache, v_cache] + if enable_cache_writeback: + launch_operands = [q_in, k_in, v_in, gp_out, kv_cache] + else: + launch_operands = [q_in, k_in, v_in, gp_out, k_cache, v_cache] @launch( operands=launch_operands, sizes=launch_sizes, ) def launch_body(*launch_args): - if dv_chunks > 1: - lx, ly, lz, lsx, lsy, lsz, q, k, v, gp, kcache, vcache = launch_args + if enable_cache_writeback: + if dv_chunks > 1: + lx, ly, lz, lsx, lsy, lsz, q, k, v, gp, kv_cache_arg = launch_args + else: + lx, ly, lsx, lsy, q, k, v, gp, kv_cache_arg = launch_args + lz = ConstantOp(index_type, 0) else: - lx, ly, lsx, lsy, q, k, v, gp, kcache, vcache = launch_args - lz = ConstantOp(index_type, 0) + if dv_chunks > 1: + lx, ly, lz, lsx, lsy, lsz, q, k, v, gp, kcache, vcache = launch_args + else: + lx, ly, lsx, lsy, q, k, v, gp, kcache, vcache = launch_args + lz = ConstantOp(index_type, 0) # Compute Q offset from launch iteration index affine_map_q_launch = AffineMap.get( @@ -484,43 +565,77 @@ def launch_body(*launch_args): ) # ---------------------------------------------------------- - # K cache get (L1→L3 from tx=0 tiles per stage) + # KV cache gets (L1→L3 via CacheWB channel) + # Interleaved KV cache: [K_c0, V_c0, K_c1, V_c1, ...] + # Each chunk pair occupies 2 * lkp * dk_tile elements. + # The tile BD chain alternates BD0(K)→BD1(V), matching + # this K,V,K,V get ordering. # ---------------------------------------------------------- - if enable_k_writeback: - # GQA: skip duplicate writes for heads sharing same KV head + if enable_cache_writeback: is_first_in_gqa_group = ( gqa_group_size == 1 or head_local % gqa_group_size == 0 ) if is_first_in_gqa_group: - for stage in range(NS): - k_stage_off_val_wb = stage * lk_per_stage * dk - k_combined_wb = affine_apply( - affine_map_add, + # KV head offset into the flat kv_cache buffer + kv_head_off = affine_apply( + AffineMap.get( + 0, + 1, [ - head_k_off, - ConstantOp(index_type, k_stage_off_val_wb), - ], - ) - ChannelGet( - "KWB", - kcache, - indices=[ - head_offset_idx, - ConstantOp(index_type, stage), - ConstantOp(index_type, 0), - ], - offsets=[0, k_combined_wb], - sizes=[ - chunks_per_stage, - dk_chunks, - lkp, - dk_tile, + AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get( + num_chunks * 2 * lkp * dk_tile + ), + ) ], - strides=[lkp * dk, dk_tile, dk, 1], + ), + [kv_head_idx], + ) + # Use scf.for to allow the compiler to fold + # consecutive CacheWB BDs into higher-dimensional + # shim DMA BDs, avoiding BD exhaustion at scale. + c_cps2 = ConstantOp(index_type, chunks_per_stage * 2) + affine_map_kv_off = AffineMap.get( + 0, + 2, + [ + AffineExpr.get_add( + AffineSymbolExpr.get(0), + AffineExpr.get_mul( + AffineSymbolExpr.get(1), + AffineConstantExpr.get(lkp * dk_tile), + ), + ) + ], + ) + for stage in range(NS): + stage_base_val = stage * chunks_per_stage * kv_chunk_stride + stage_base = affine_apply( + affine_map_add, + [kv_head_off, ConstantOp(index_type, stage_base_val)], ) + for chunk_kv_iter in scf_range(0, c_cps2, 1): + wb_off = affine_apply( + affine_map_kv_off, + [stage_base, chunk_kv_iter], + ) + ChannelGet( + "CacheWB", + kv_cache_arg, + indices=[ + head_offset_idx, + ConstantOp(index_type, stage), + ConstantOp(index_type, 0), + ], + offsets=[wb_off], + sizes=[lkp * dk_tile], + strides=[1], + ) + yield_([]) # ---------------------------------------------------------- - # Output get (after K cache, matches segment ordering) + # Output get (after KV cache, matches segment ordering) # ---------------------------------------------------------- ChannelGet( "GpOut", @@ -810,13 +925,10 @@ def herd_body(tx, ty, hsx, hsy, *all_args): [q_bufs[dk_c], qk_tmp, g1d], ) - # K write-back (tx=0 only, L1→L3) - # Copy K data to a separate staging buffer first - # to avoid DMA race with the next chunk's K receive - # into qk_tmp. Data in qk_tmp is in tiled [M,M] - # format from QK2L1 relay. Un-tile to row-major - # [lkp, dk_tile] for the K cache. - if enable_k_writeback: + # K write-back via CacheWB (tx=0 only) + # Copy K to staging buffer, then send via CacheWB. + # This is the first BD in the 2-BD rotation. + if enable_cache_writeback and kwb_buf is not None: cmp_tx0 = arith.CmpIOp( arith.CmpIPredicate.eq, arith.IndexCastOp(i32, tx), @@ -832,7 +944,7 @@ def herd_body(tx, ty, hsx, hsy, *all_args): ) with InsertionPoint(if_kwb.then_block): ChannelPut( - "KWB", + "CacheWB", kwb_buf, indices=[ h_seg_x, @@ -903,6 +1015,51 @@ def herd_body(tx, ty, hsx, hsy, *all_args): CallOp([], "mul_r_gp", [r_tmp.result, gp]) CallOp([], "matmul_g_b_bf16", [g1d, v, gp]) + # V write-back via CacheWB (tx=0 only) + # Copy V to kwb_buf staging buffer (same buffer as + # K writeback) to avoid DMA race with V2L1 receive. + # Both K and V use the same buffer → single BD → + # single lock → proper serialization. + if enable_cache_writeback and kwb_buf is not None: + cmp_tx0_v = arith.CmpIOp( + arith.CmpIPredicate.eq, + arith.IndexCastOp(i32, tx), + arith.ConstantOp(i32, 0), + ) + if_tx0_v = scf.IfOp(cmp_tx0_v) + with InsertionPoint(if_tx0_v.then_block): + CallOp([], "copy_tile", [v, kwb_buf]) + for s in range(NS): + if_vwb = affine.AffineIfOp( + stage_sets[s], + cond_operands=[tx, ty], + ) + with InsertionPoint(if_vwb.then_block): + ChannelPut( + "CacheWB", + kwb_buf, + indices=[ + h_seg_x, + ConstantOp(index_type, s), + tx, + ], + offsets=[0, 0, 0, 0], + sizes=[ + lkp // M, + M, + dv_tile // M, + M, + ], + strides=[ + M * M, + M, + lkp * M, + 1, + ], + ) + affine.AffineYieldOp([]) + scf.YieldOp([]) + c0_i32 = ConstantOp(i32, 0) CallOp( [], @@ -1255,6 +1412,11 @@ def _emit_counter_increment(): action="store_true", help="Disable K cache write-back (for debugging)", ) + parser.add_argument( + "--no-v-writeback", + action="store_true", + help="Disable V cache write-back (for debugging)", + ) parser.add_argument( "--no-rope", action="store_true", @@ -1274,6 +1436,7 @@ def _emit_counter_increment(): num_kv_heads = args.num_kv_heads if args.num_kv_heads is not None else num_heads causal = args.causal enable_k_writeback = not args.no_k_writeback + enable_v_writeback = not args.no_v_writeback enable_rope = not args.no_rope gqa_group_size = num_heads // num_kv_heads @@ -1290,6 +1453,7 @@ def _emit_counter_increment(): num_kv_heads=num_kv_heads, causal=causal, enable_k_writeback=enable_k_writeback, + enable_v_writeback=enable_v_writeback, ) if args.print_module_only: @@ -1399,20 +1563,43 @@ def apply_rope_ref(x, lut_slice): target_device="npu2", ) - # K cache expected output: RoPE'd K - expected_k_cache = k_roped.copy() + # Build expected KV cache (interleaved: [K_c0, V_c0, K_c1, V_c1, ...]) + enable_cache_writeback = enable_k_writeback or enable_v_writeback + if enable_cache_writeback: + num_chunks_host = lk // lkp + kv_cache_size = num_kv_heads * num_chunks_host * 2 * lkp * dk + expected_kv_cache = np.zeros(kv_cache_size, dtype=INPUT_DATATYPE) + for h in range(num_kv_heads): + for c in range(num_chunks_host): + k_off = h * num_chunks_host * 2 * lkp * dk + c * 2 * lkp * dk + v_off = k_off + lkp * dk + # K chunk: RoPE'd K data, row-major [lkp, dk] + expected_kv_cache[k_off : k_off + lkp * dk] = k_roped[ + h, c * lkp : (c + 1) * lkp, : + ].flatten() + # V chunk: raw V data in transposed tile format [lkp, dv_tile] + # V was transposed to [num_kv_heads * dv_chunks, lk, dv_tile] + # For dv_chunks=1, this is just [num_kv_heads, lk, dv_tile] + expected_kv_cache[v_off : v_off + lkp * dk] = input_v[ + h, c * lkp : (c + 1) * lkp, : + ].flatten() + else: + expected_k_cache = k_roped.copy() if args.compile_mode == "compile-and-run": import filelock, tempfile backend = XRTBackend(**backend_opts) - # 3 output buffers: attention, K cache, V cache (V cache unused by NPU) - v_cache_placeholder = np.zeros_like(input_v) - expected_outputs = [ - sdpa_output_transposed, - expected_k_cache, - v_cache_placeholder, - ] + if enable_cache_writeback: + kv_cache_placeholder = np.zeros(kv_cache_size, dtype=INPUT_DATATYPE) + expected_outputs = [sdpa_output_transposed, kv_cache_placeholder] + else: + v_cache_placeholder = np.zeros_like(input_v) + expected_outputs = [ + sdpa_output_transposed, + k_roped.copy(), + v_cache_placeholder, + ] output_placeholders = [np.zeros(o.shape, o.dtype) for o in expected_outputs] # NPU receives pre-RoPE'd Q and K (RoPE applied on host when enabled) input_list = [npu_input_q, npu_input_k, input_v] @@ -1455,79 +1642,45 @@ def apply_rope_ref(x, lut_slice): "(BFP16 emulation tolerance)" ) - # --- Output 1: K cache (should be RoPE'd K) --- - if enable_k_writeback: - k_actual = actual_outputs[1].reshape(expected_k_cache.shape) - k_mismatches = int(np.sum(k_actual != expected_k_cache)) - k_total = k_actual.size - print(f"Output 1 (K cache): mismatches={k_mismatches}/{k_total}") - if k_mismatches > 0: - print( - f"FAIL: K cache has {k_mismatches} mismatches (expected RoPE'd K)" - ) - # Debug: show match pattern per head and row block - for h in range(expected_k_cache.shape[0]): - for rb in range(expected_k_cache.shape[1] // lkp): - chunk = k_actual[h, rb * lkp : (rb + 1) * lkp, :].astype( - np.float32 - ) - exp_chunk = expected_k_cache[ - h, rb * lkp : (rb + 1) * lkp, : - ].astype(np.float32) - m = int( + # --- Output 1: KV cache (interleaved [K_c0, V_c0, K_c1, V_c1, ...]) --- + if enable_cache_writeback: + kv_actual = actual_outputs[1].flatten() + kv_mismatches = int(np.sum(kv_actual != expected_kv_cache)) + kv_total = kv_actual.size + chunk_size = lkp * dk + num_chunks_total = num_kv_heads * (lk // lkp) + print(f"Output 1 (KV cache): mismatches={kv_mismatches}/{kv_total}") + if kv_mismatches > 0: + print(f"FAIL: KV cache has {kv_mismatches} mismatches") + for h in range(num_kv_heads): + for c in range(lk // lkp): + k_off = h * (lk // lkp) * 2 * chunk_size + c * 2 * chunk_size + v_off = k_off + chunk_size + k_m = int( np.sum( - k_actual[h, rb * lkp : (rb + 1) * lkp, :] - == expected_k_cache[h, rb * lkp : (rb + 1) * lkp, :] + kv_actual[k_off : k_off + chunk_size] + == expected_kv_cache[k_off : k_off + chunk_size] ) ) - t = chunk.size - corr = ( - float( - np.corrcoef(chunk.flatten(), exp_chunk.flatten())[0, 1] + v_m = int( + np.sum( + kv_actual[v_off : v_off + chunk_size] + == expected_kv_cache[v_off : v_off + chunk_size] ) - if np.std(chunk) > 0 and np.std(exp_chunk) > 0 - else 0.0 ) - # Check if wrong chunk matches a different chunk - if m < t: - best_match = -1 - best_corr = -1.0 - for rb2 in range(expected_k_cache.shape[1] // lkp): - exp2 = expected_k_cache[ - h, rb2 * lkp : (rb2 + 1) * lkp, : - ].astype(np.float32) - c2 = ( - float( - np.corrcoef(chunk.flatten(), exp2.flatten())[ - 0, 1 - ] - ) - if np.std(exp2) > 0 - else 0.0 - ) - if c2 > best_corr: - best_corr = c2 - best_match = rb2 - print( - f" head={h} chunk={rb}: {m}/{t} match, corr={corr:.4f}, best_match=chunk{best_match} (corr={best_corr:.4f})" - ) - else: - print(f" head={h} chunk={rb}: {m}/{t} match (EXACT)") - diff_idx = np.argwhere(k_actual != expected_k_cache) - for idx in diff_idx[:10]: - idx_t = tuple(idx) - print( - f" {idx_t}: expected={expected_k_cache[idx_t]}, " - f"actual={k_actual[idx_t]}" - ) + k_label = ( + "EXACT" if k_m == chunk_size else f"{k_m}/{chunk_size}" + ) + v_label = ( + "EXACT" if v_m == chunk_size else f"{v_m}/{chunk_size}" + ) + if k_m < chunk_size or v_m < chunk_size: + print(f" h={h} c={c}: K={k_label}, V={v_label}") failed = True else: - print("PASS: K cache matches RoPE'd K") + print("PASS: KV cache matches expected (K=RoPE'd K, V=raw V)") else: - print("Output 1 (K cache): SKIPPED (K write-back disabled)") - - # Output 2: V cache — not written by NPU (skipped) - print("Output 2 (V cache): SKIPPED (not written by NPU, use host copy)") + print("Output 1 (KV cache): SKIPPED (cache write-back disabled)") if failed: print("OVERALL: FAILED") diff --git a/programming_examples/flash_attention/kv_cache_prefill/test_elf_npu2.cpp b/programming_examples/flash_attention/kv_cache_prefill/test_elf_npu2.cpp index 41b9e0ef0..2c8713fa2 100644 --- a/programming_examples/flash_attention/kv_cache_prefill/test_elf_npu2.cpp +++ b/programming_examples/flash_attention/kv_cache_prefill/test_elf_npu2.cpp @@ -122,14 +122,15 @@ int main(int argc, const char *argv[]) { auto kernel = xrt::ext::kernel(context, kernelName); // Create buffer objects using xrt::ext::bo (declared as xrt::bo type) - // Kernel signature: attention_bf16(Q, K, V, Output, K_cache, V_cache) + // Kernel signature: attention_bf16(Q, K, V, Output, KV_cache) + // KV cache is interleaved: [K_c0, V_c0, K_c1, V_c1, ...] + size_t KV_CACHE_SIZE = (size_t)num_heads * lk * dk * 2 * sizeof(DATATYPE); xrt::bo bo_q = xrt::ext::bo{device, Q_SIZE}; xrt::bo bo_k = xrt::ext::bo{device, K_SIZE}; xrt::bo bo_v = xrt::ext::bo{device, V_SIZE}; xrt::bo bo_out = xrt::ext::bo{device, OUTPUT_SIZE + static_cast(trace_size)}; - xrt::bo bo_k_cache = xrt::ext::bo{device, K_SIZE}; - xrt::bo bo_v_cache = xrt::ext::bo{device, V_SIZE}; + xrt::bo bo_kv_cache = xrt::ext::bo{device, KV_CACHE_SIZE}; unsigned n_iterations = vm["iterations"].as(); unsigned n_warmup_iterations = vm["warmup"].as(); @@ -180,17 +181,14 @@ int main(int argc, const char *argv[]) { DATATYPE *bufOut = bo_out.map(); memset(bufOut, 0, OUTPUT_SIZE + trace_size); - DATATYPE *bufKCache = bo_k_cache.map(); - memset(bufKCache, 0, K_SIZE); - DATATYPE *bufVCache = bo_v_cache.map(); - memset(bufVCache, 0, V_SIZE); + DATATYPE *bufKVCache = bo_kv_cache.map(); + memset(bufKVCache, 0, KV_CACHE_SIZE); bo_q.sync(XCL_BO_SYNC_BO_TO_DEVICE); bo_k.sync(XCL_BO_SYNC_BO_TO_DEVICE); bo_v.sync(XCL_BO_SYNC_BO_TO_DEVICE); bo_out.sync(XCL_BO_SYNC_BO_TO_DEVICE); - bo_k_cache.sync(XCL_BO_SYNC_BO_TO_DEVICE); - bo_v_cache.sync(XCL_BO_SYNC_BO_TO_DEVICE); + bo_kv_cache.sync(XCL_BO_SYNC_BO_TO_DEVICE); for (unsigned iter = 0; iter < num_iter; iter++) { if (verbosity >= 1) @@ -201,8 +199,7 @@ int main(int argc, const char *argv[]) { run.set_arg(1, bo_k); run.set_arg(2, bo_v); run.set_arg(3, bo_out); - run.set_arg(4, bo_k_cache); - run.set_arg(5, bo_v_cache); + run.set_arg(4, bo_kv_cache); auto start = std::chrono::high_resolution_clock::now(); run.start(); @@ -210,8 +207,7 @@ int main(int argc, const char *argv[]) { auto stop = std::chrono::high_resolution_clock::now(); bo_out.sync(XCL_BO_SYNC_BO_FROM_DEVICE); - bo_k_cache.sync(XCL_BO_SYNC_BO_FROM_DEVICE); - bo_v_cache.sync(XCL_BO_SYNC_BO_FROM_DEVICE); + bo_kv_cache.sync(XCL_BO_SYNC_BO_FROM_DEVICE); if (iter < n_warmup_iterations) { /* Warmup iterations do not count towards average runtime. */ From 83cf14e4fbe4bdc2813c2dd976805e52a616e847 Mon Sep 17 00:00:00 2001 From: erweiw Date: Wed, 8 Apr 2026 11:33:23 -0700 Subject: [PATCH 3/6] Generalize interleaved KV cache layout for dk/dv scaling Update the interleaved KV cache layout to support dk_chunks > 1: - Per chunk stores [K_dk0, ..., K_dk(N-1), V_dv_lz] with N = dk_chunks - KV cache outer dimension combines (kv_head, dv_chunk) like V L3 layout - Launch body scf.for iterates cache_slots_per_chunk = dk_chunks + 1 - Host test constructs expected data with per-dk_tile K slots Currently dk=dv=128 fails at the aiecc level due to L1 memory exhaustion (kwb_buf staging buffer + extra Q saved buffer exceeds 64KB), not due to layout issues. The generalized layout is ready for when L1 capacity is freed (e.g., by eliminating the staging buffer via compiler lock fixes). Co-Authored-By: Claude Opus 4.6 (1M context) --- .../kv_cache_prefill/attn_npu2.py | 160 ++++++++++-------- .../kv_cache_prefill/test_elf_npu2.cpp | 11 +- 2 files changed, 100 insertions(+), 71 deletions(-) diff --git a/programming_examples/flash_attention/kv_cache_prefill/attn_npu2.py b/programming_examples/flash_attention/kv_cache_prefill/attn_npu2.py index b51db1895..3751cbb2f 100644 --- a/programming_examples/flash_attention/kv_cache_prefill/attn_npu2.py +++ b/programming_examples/flash_attention/kv_cache_prefill/attn_npu2.py @@ -141,12 +141,9 @@ def build_module( assert dv % dv_tile == 0, f"dv ({dv}) must be divisible by dv_tile/lkp ({dv_tile})" dv_chunks = dv // dv_tile enable_cache_writeback = enable_k_writeback or enable_v_writeback - if enable_cache_writeback: - assert dk == lkp and dv == lkp, ( - f"Interleaved KV cache write-back requires dk == dv == lkp, " - f"got dk={dk}, dv={dv}, lkp={lkp}. " - f"Use --no-k-writeback --no-v-writeback to disable." - ) + # Number of dk_tile-sized slots per chunk in the interleaved KV cache: + # dk_chunks K tiles + 1 V tile (V is always one dv_tile per launch iter) + cache_slots_per_chunk = dk_chunks + 1 if causal: assert lq == lk, f"Causal masking requires lq == lk, got lq={lq}, lk={lk}" assert lqp // num_q_tiles == lkp, ( @@ -214,13 +211,18 @@ def build_module( gp_l3_t = MemRefType.get([num_heads * dv_chunks, lq, dv_tile], bf16) # KV cache L3 types - if enable_k_writeback or enable_v_writeback: - # Interleaved KV cache: [num_kv_heads, num_chunks, 2, lkp, dk_tile] - # K at index 0, V at index 1 within each chunk pair + if enable_cache_writeback: + # Interleaved KV cache per chunk per (kv_head, dv_chunk) pair: + # [K_dk0, K_dk1, ..., K_dk(dk_chunks-1), V_dv_lz] + # Each slot is [lkp, dk_tile] = lkp*dk_tile elements. + # The outer dimension combines kv_head and dv_chunk (like V L3 layout). + # K is redundantly stored per dv_chunk (same data, different lz iters). + kv_slot_size = lkp * dk_tile # elements per slot + kv_chunk_stride = cache_slots_per_chunk * kv_slot_size + num_kv_cache_heads = num_kv_heads * dv_chunks kv_cache_l3_t = MemRefType.get( - [num_kv_heads * num_chunks * 2 * lkp * dk_tile], bf16 + [num_kv_cache_heads * num_chunks * kv_chunk_stride], bf16 ) - kv_chunk_stride = 2 * lkp * dk_tile # stride between K and V in a pair # Legacy separate caches (when writeback disabled, used as placeholders) k_cache_l3_t = MemRefType.get([num_kv_heads, lk, dk], bf16) v_cache_l3_t = MemRefType.get([num_kv_heads * dv_chunks, lk, dv_tile], bf16) @@ -576,26 +578,40 @@ def launch_body(*launch_args): gqa_group_size == 1 or head_local % gqa_group_size == 0 ) if is_first_in_gqa_group: - # KV head offset into the flat kv_cache buffer - kv_head_off = affine_apply( - AffineMap.get( - 0, - 1, - [ - AffineExpr.get_mul( - AffineSymbolExpr.get(0), - AffineConstantExpr.get( - num_chunks * 2 * lkp * dk_tile + # KV cache head offset uses the same combined + # (kv_head * dv_chunks + lz) indexing as head_v_off, + # but with kv_chunk_stride instead of lk * dv_tile. + # head_v_off = (kv_head * dv_chunks + lz) * lk * dv_tile + # kv_head_off = (kv_head * dv_chunks + lz) * num_chunks * kv_chunk_stride + affine_map_kv_head = AffineMap.get( + 0, + 2, + [ + AffineExpr.get_mul( + AffineExpr.get_add( + AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(dv_chunks), ), - ) - ], - ), - [kv_head_idx], + AffineSymbolExpr.get(1), + ), + AffineConstantExpr.get( + num_chunks * kv_chunk_stride + ), + ) + ], + ) + kv_head_off = affine_apply( + affine_map_kv_head, [kv_head_idx, lz] ) # Use scf.for to allow the compiler to fold # consecutive CacheWB BDs into higher-dimensional # shim DMA BDs, avoiding BD exhaustion at scale. - c_cps2 = ConstantOp(index_type, chunks_per_stage * 2) + # Each chunk has cache_slots_per_chunk slots + # (dk_chunks K tiles + 1 V tile). + c_slots = ConstantOp( + index_type, chunks_per_stage * cache_slots_per_chunk + ) affine_map_kv_off = AffineMap.get( 0, 2, @@ -604,7 +620,7 @@ def launch_body(*launch_args): AffineSymbolExpr.get(0), AffineExpr.get_mul( AffineSymbolExpr.get(1), - AffineConstantExpr.get(lkp * dk_tile), + AffineConstantExpr.get(kv_slot_size), ), ) ], @@ -615,7 +631,7 @@ def launch_body(*launch_args): affine_map_add, [kv_head_off, ConstantOp(index_type, stage_base_val)], ) - for chunk_kv_iter in scf_range(0, c_cps2, 1): + for chunk_kv_iter in scf_range(0, c_slots, 1): wb_off = affine_apply( affine_map_kv_off, [stage_base, chunk_kv_iter], @@ -1563,26 +1579,40 @@ def apply_rope_ref(x, lut_slice): target_device="npu2", ) - # Build expected KV cache (interleaved: [K_c0, V_c0, K_c1, V_c1, ...]) + # Build expected KV cache (interleaved layout). + # Per (kv_head, dv_chunk) pair, per chunk: + # [K_dk0, K_dk1, ..., K_dk(dk_chunks-1), V_dv_lz] enable_cache_writeback = enable_k_writeback or enable_v_writeback + dk_chunks_host = dk // lkp + slot_size = lkp * lkp # lkp * dk_tile = lkp * lkp + slots_per_chunk = dk_chunks_host + 1 + chunk_stride = slots_per_chunk * slot_size if enable_cache_writeback: num_chunks_host = lk // lkp - kv_cache_size = num_kv_heads * num_chunks_host * 2 * lkp * dk + num_cache_heads = num_kv_heads * dv_chunks_host + kv_cache_size = num_cache_heads * num_chunks_host * chunk_stride expected_kv_cache = np.zeros(kv_cache_size, dtype=INPUT_DATATYPE) for h in range(num_kv_heads): - for c in range(num_chunks_host): - k_off = h * num_chunks_host * 2 * lkp * dk + c * 2 * lkp * dk - v_off = k_off + lkp * dk - # K chunk: RoPE'd K data, row-major [lkp, dk] - expected_kv_cache[k_off : k_off + lkp * dk] = k_roped[ - h, c * lkp : (c + 1) * lkp, : - ].flatten() - # V chunk: raw V data in transposed tile format [lkp, dv_tile] - # V was transposed to [num_kv_heads * dv_chunks, lk, dv_tile] - # For dv_chunks=1, this is just [num_kv_heads, lk, dv_tile] - expected_kv_cache[v_off : v_off + lkp * dk] = input_v[ - h, c * lkp : (c + 1) * lkp, : - ].flatten() + for dv_idx in range(dv_chunks_host): + cache_head = h * dv_chunks_host + dv_idx + for c in range(num_chunks_host): + base = ( + cache_head * num_chunks_host * chunk_stride + c * chunk_stride + ) + # K tiles: dk_chunks tiles of [lkp, dk_tile] + for dk_idx in range(dk_chunks_host): + k_off = base + dk_idx * slot_size + expected_kv_cache[k_off : k_off + slot_size] = k_roped[ + h, + c * lkp : (c + 1) * lkp, + dk_idx * lkp : (dk_idx + 1) * lkp, + ].flatten() + # V tile: 1 tile of [lkp, dv_tile] + v_off = base + dk_chunks_host * slot_size + v_head = h * dv_chunks_host + dv_idx + expected_kv_cache[v_off : v_off + slot_size] = input_v[ + v_head, c * lkp : (c + 1) * lkp, : + ].flatten() else: expected_k_cache = k_roped.copy() @@ -1642,43 +1672,35 @@ def apply_rope_ref(x, lut_slice): "(BFP16 emulation tolerance)" ) - # --- Output 1: KV cache (interleaved [K_c0, V_c0, K_c1, V_c1, ...]) --- + # --- Output 1: KV cache (interleaved) --- if enable_cache_writeback: kv_actual = actual_outputs[1].flatten() kv_mismatches = int(np.sum(kv_actual != expected_kv_cache)) kv_total = kv_actual.size - chunk_size = lkp * dk - num_chunks_total = num_kv_heads * (lk // lkp) print(f"Output 1 (KV cache): mismatches={kv_mismatches}/{kv_total}") if kv_mismatches > 0: print(f"FAIL: KV cache has {kv_mismatches} mismatches") - for h in range(num_kv_heads): + for ch in range(num_cache_heads): for c in range(lk // lkp): - k_off = h * (lk // lkp) * 2 * chunk_size + c * 2 * chunk_size - v_off = k_off + chunk_size - k_m = int( - np.sum( - kv_actual[k_off : k_off + chunk_size] - == expected_kv_cache[k_off : k_off + chunk_size] - ) - ) - v_m = int( - np.sum( - kv_actual[v_off : v_off + chunk_size] - == expected_kv_cache[v_off : v_off + chunk_size] + base = ch * (lk // lkp) * chunk_stride + c * chunk_stride + # Check each slot + for s in range(slots_per_chunk): + s_off = base + s * slot_size + s_m = int( + np.sum( + kv_actual[s_off : s_off + slot_size] + == expected_kv_cache[s_off : s_off + slot_size] + ) ) - ) - k_label = ( - "EXACT" if k_m == chunk_size else f"{k_m}/{chunk_size}" - ) - v_label = ( - "EXACT" if v_m == chunk_size else f"{v_m}/{chunk_size}" - ) - if k_m < chunk_size or v_m < chunk_size: - print(f" h={h} c={c}: K={k_label}, V={v_label}") + if s_m < slot_size: + label = "K" if s < dk_chunks_host else "V" + print( + f" ch={ch} c={c} slot={s}({label}): " + f"{s_m}/{slot_size}" + ) failed = True else: - print("PASS: KV cache matches expected (K=RoPE'd K, V=raw V)") + print("PASS: KV cache matches expected") else: print("Output 1 (KV cache): SKIPPED (cache write-back disabled)") diff --git a/programming_examples/flash_attention/kv_cache_prefill/test_elf_npu2.cpp b/programming_examples/flash_attention/kv_cache_prefill/test_elf_npu2.cpp index 2c8713fa2..e9cbf2690 100644 --- a/programming_examples/flash_attention/kv_cache_prefill/test_elf_npu2.cpp +++ b/programming_examples/flash_attention/kv_cache_prefill/test_elf_npu2.cpp @@ -123,8 +123,15 @@ int main(int argc, const char *argv[]) { // Create buffer objects using xrt::ext::bo (declared as xrt::bo type) // Kernel signature: attention_bf16(Q, K, V, Output, KV_cache) - // KV cache is interleaved: [K_c0, V_c0, K_c1, V_c1, ...] - size_t KV_CACHE_SIZE = (size_t)num_heads * lk * dk * 2 * sizeof(DATATYPE); + // KV cache is interleaved: [K_dk0, ..., K_dk(N-1), V_dv_lz] per chunk + // Per (kv_head, dv_chunk) pair: num_chunks * (dk_chunks + 1) * lkp * dk_tile + int lkp = 64; // tile size + int dk_chunks = dk / lkp; + int dv_chunks = dv / lkp; + int slots_per_chunk = dk_chunks + 1; + int num_chunks = lk / lkp; + size_t KV_CACHE_SIZE = (size_t)num_heads * dv_chunks * num_chunks * + slots_per_chunk * lkp * lkp * sizeof(DATATYPE); xrt::bo bo_q = xrt::ext::bo{device, Q_SIZE}; xrt::bo bo_k = xrt::ext::bo{device, K_SIZE}; xrt::bo bo_v = xrt::ext::bo{device, V_SIZE}; From 77e6a1b2aa26f27c7a020dde7da3eca30b97cabb Mon Sep 17 00:00:00 2001 From: erweiw Date: Wed, 8 Apr 2026 11:39:35 -0700 Subject: [PATCH 4/6] Address PR review comments - Remove run_test.sh with hardcoded machine-specific paths - Fix lit test CHECK pattern: OVERALL: PASSED (not PASS!) - Fix misleading RoPE message in C++ profiler: host pre-rotation - Add missing C++ standard headers (chrono, cstring, cstdlib) - Document GQA duplicate-write behavior for gqa_group_size > unroll Co-Authored-By: Claude Opus 4.6 (1M context) --- .../kv_cache_prefill/attn_npu2.py | 4 ++++ .../run_npu2_makefile_peano_elf.lit | 2 +- .../kv_cache_prefill/run_test.sh | 20 ------------------- .../kv_cache_prefill/test_elf_npu2.cpp | 5 ++++- 4 files changed, 9 insertions(+), 22 deletions(-) delete mode 100755 programming_examples/flash_attention/kv_cache_prefill/run_test.sh diff --git a/programming_examples/flash_attention/kv_cache_prefill/attn_npu2.py b/programming_examples/flash_attention/kv_cache_prefill/attn_npu2.py index 3751cbb2f..165a58e4c 100644 --- a/programming_examples/flash_attention/kv_cache_prefill/attn_npu2.py +++ b/programming_examples/flash_attention/kv_cache_prefill/attn_npu2.py @@ -574,6 +574,10 @@ def launch_body(*launch_args): # this K,V,K,V get ordering. # ---------------------------------------------------------- if enable_cache_writeback: + # GQA: skip duplicate writes within the same unroll group. + # Note: for gqa_group_size > num_heads_per_unroll, heads + # in different unroll groups sharing the same KV head will + # still write redundantly (same data, harmless). is_first_in_gqa_group = ( gqa_group_size == 1 or head_local % gqa_group_size == 0 ) diff --git a/programming_examples/flash_attention/kv_cache_prefill/run_npu2_makefile_peano_elf.lit b/programming_examples/flash_attention/kv_cache_prefill/run_npu2_makefile_peano_elf.lit index 6f4a2d210..0af0f7e72 100644 --- a/programming_examples/flash_attention/kv_cache_prefill/run_npu2_makefile_peano_elf.lit +++ b/programming_examples/flash_attention/kv_cache_prefill/run_npu2_makefile_peano_elf.lit @@ -7,4 +7,4 @@ // RUN: cd test_npu2_peano_elf // RUN: make -f %S/Makefile clean // RUN: make -f %S/Makefile run PEANO_INSTALL_DIR=%PEANO_INSTALL_DIR | FileCheck %s -// CHECK: PASS! +// CHECK: OVERALL: PASSED diff --git a/programming_examples/flash_attention/kv_cache_prefill/run_test.sh b/programming_examples/flash_attention/kv_cache_prefill/run_test.sh deleted file mode 100755 index d11041d07..000000000 --- a/programming_examples/flash_attention/kv_cache_prefill/run_test.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash -set -e - -# Source XRT first -source /opt/xilinx/xrt/setup.sh 2>/dev/null || true - -# Activate sandbox venv (after XRT to preserve PATH) -source /home/strixminipc/new_session/mlir-air/sandbox/bin/activate - -# Set paths - build/bin first so we get the C++ aircc binary -export PATH=/home/strixminipc/new_session/mlir-air/build/bin:/home/strixminipc/new_session/mlir-air/mlir-aie/install/bin:$PATH -export PYTHONPATH=/home/strixminipc/new_session/mlir-air/build/python:/home/strixminipc/new_session/mlir-air/mlir-aie/install/python:$PYTHONPATH -export LD_LIBRARY_PATH=/home/strixminipc/new_session/mlir-air/build/lib:/home/strixminipc/new_session/mlir-air/mlir-aie/install/lib:$LD_LIBRARY_PATH - -# Peano compiler (llvm-aie) installed as pip package in sandbox -export PEANO_INSTALL_DIR=/home/strixminipc/new_session/mlir-air/sandbox/lib/python3.13/site-packages/llvm-aie - -cd /home/strixminipc/new_session/mlir-air/programming_examples/flash_attention/kv_cache_prefill/build_peano - -exec python3 ../attn_npu2.py "$@" diff --git a/programming_examples/flash_attention/kv_cache_prefill/test_elf_npu2.cpp b/programming_examples/flash_attention/kv_cache_prefill/test_elf_npu2.cpp index e9cbf2690..f55b17f2a 100644 --- a/programming_examples/flash_attention/kv_cache_prefill/test_elf_npu2.cpp +++ b/programming_examples/flash_attention/kv_cache_prefill/test_elf_npu2.cpp @@ -8,8 +8,11 @@ #include "cxxopts.hpp" #include +#include #include #include +#include +#include #include #include #include @@ -162,7 +165,7 @@ int main(int argc, const char *argv[]) { << V_SIZE << " bytes)" << std::endl; std::cout << " Output: [" << num_heads << "x" << lq << "x" << dv << "] (" << OUTPUT_SIZE << " bytes)" << std::endl; - std::cout << " RoPE: computed in-kernel (sincos)" << std::endl; + std::cout << " RoPE: host pre-rotation" << std::endl; if (verbosity >= 1) std::cout << "Writing data into buffer objects.\n"; From 15e9cf60a85fc1cc1fd2b14dd072e9f9239a0be3 Mon Sep 17 00:00:00 2001 From: erweiw Date: Wed, 8 Apr 2026 11:53:12 -0700 Subject: [PATCH 5/6] Remove compiler changes (moved to PR #1515) The lock placement fix for outbound puts sharing a staging buffer has been moved to a separate PR (#1515). This PR now contains only the programming example changes. The example requires PR #1515 to be merged first for V cache write-back to work correctly. Co-Authored-By: Claude Opus 4.6 (1M context) --- mlir/lib/Conversion/AIRToAIEPass.cpp | 134 +++++++++++++++++---------- 1 file changed, 83 insertions(+), 51 deletions(-) diff --git a/mlir/lib/Conversion/AIRToAIEPass.cpp b/mlir/lib/Conversion/AIRToAIEPass.cpp index 8349d9cea..24e5f6c53 100644 --- a/mlir/lib/Conversion/AIRToAIEPass.cpp +++ b/mlir/lib/Conversion/AIRToAIEPass.cpp @@ -2331,6 +2331,11 @@ struct SpecializeChannelBundlePattern for (auto put : channelPuts) { auto indices_uint = air::convertVecOfConstIndexToVecOfUInt(put.getIndices()); + if (indices_uint.empty() && !put.getIndices().empty() && iter == 0) + put->emitWarning( + "channel bundle indices cannot be resolved to compile-time " + "constants; this channel put will be replaced with " + "air.wait_all, which may cause data loss"); if (areIdenticalVectors(indices_uint, position)) { // Found channel put for this channel rewriter.setInsertionPoint(put); @@ -2349,6 +2354,11 @@ struct SpecializeChannelBundlePattern for (auto get : channelGets) { auto indices_uint = air::convertVecOfConstIndexToVecOfUInt(get.getIndices()); + if (indices_uint.empty() && !get.getIndices().empty() && iter == 0) + get->emitWarning( + "channel bundle indices cannot be resolved to compile-time " + "constants; this channel get will be replaced with " + "air.wait_all, which may cause data loss"); if (areIdenticalVectors(indices_uint, position)) { // Found channel get for this channel rewriter.setInsertionPoint(get); @@ -3013,57 +3023,47 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase { return AIE::PacketFlowOp(); // Only air.channel_interface ops support // packet-flow routing. - // Determine if this is a shim flow by checking if EITHER source OR - // destination tile is a shim tile. This must be consistent with - // placeDMAChannelsAndRouteFlows which uses the same criteria. - auto sourceTileOp = source.getDefiningOp(); - bool sourceIsShim = sourceTileOp && sourceTileOp.isShimNOCorPLTile(); - - // Check if the destination involves a shim tile by examining the memcpy's - // memory spaces (L3 memory space indicates shim tile involvement) - bool destIsShim = false; - if (auto srcMemref = memcpyOp.getSrcMemref()) { - auto memrefTy = dyn_cast_if_present(srcMemref.getType()); - if (memrefTy && air::isL3(memrefTy)) - destIsShim = true; - } - if (auto dstMemref = memcpyOp.getDstMemref()) { - auto memrefTy = dyn_cast_if_present(dstMemref.getType()); - if (memrefTy && air::isL3(memrefTy)) - destIsShim = true; - } - - bool isShimFlow = sourceIsShim || destIsShim; - - // Select the appropriate flow map based on whether this involves shim tiles - const SetVector &flowMap = - isShimFlow ? shimFlowOpToFlowIdMap : intraDeviceFlowOpToFlowIdMap; - - // Convert flowMap from Operation pointers to channel symbol names. + // Convert a flow map from Operation pointers to channel symbol names. // This is necessary because air.channel declarations are duplicated // under aie.device op and its parent module op, requiring symbol-based // matching. - std::vector flowOpStringsToFlowIdMap; - for (auto op : flowMap) { - auto flowChanOp = dyn_cast_if_present(op); - if (!flowChanOp) { - flowOpStringsToFlowIdMap.push_back(""); - continue; + auto buildFlowIdMap = + [](const SetVector &fmap) -> std::vector { + std::vector result; + for (auto op : fmap) { + auto flowChanOp = dyn_cast_if_present(op); + if (!flowChanOp) { + result.push_back(""); + continue; + } + result.push_back(flowChanOp.getSymName().str()); } - flowOpStringsToFlowIdMap.push_back(flowChanOp.getSymName().str()); - } + return result; + }; - // Find the flowID by matching the channel name - auto it = - llvm::find(flowOpStringsToFlowIdMap, chanIfOp.getChanName().str()); - if (it == flowOpStringsToFlowIdMap.end()) { - return AIE::PacketFlowOp(); + // Search both flow maps by channel name. Channel names are unique symbols, + // so each channel appears in exactly one map. We search the shim (device- + // host) map first, then the intra-device map. + // + // Note: we cannot reliably determine which map to search from the memcpy + // op alone, because ChannelPutOp::getDstMemref() and ChannelGetOp:: + // getSrcMemref() return nullptr by design (the other end of a channel op + // is implicit via the channel symbol). Searching both maps directly is + // simpler and always correct. + std::string chanName = chanIfOp.getChanName().str(); + + for (const auto &flowMap : {std::cref(shimFlowOpToFlowIdMap), + std::cref(intraDeviceFlowOpToFlowIdMap)}) { + auto flowStrings = buildFlowIdMap(flowMap.get()); + auto it = llvm::find(flowStrings, chanName); + if (it != flowStrings.end()) { + int flowID = std::distance(flowStrings.begin(), it); + return findPacketFlowOp(source, sourceBundle, sourceChannel, + /*checkFlowID=*/true, flowID); + } } - int flowID = std::distance(flowOpStringsToFlowIdMap.begin(), it); - // Search for the packet flow with matching source and flowID - return findPacketFlowOp(source, sourceBundle, sourceChannel, - /*checkFlowID=*/true, flowID); + return AIE::PacketFlowOp(); } /// Query an existing packet flow operation from the runtime function. @@ -4053,8 +4053,8 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase { // Get index to metadataArray based on channel indices. auto iter = air::getIndexToMetadataArrayFromChannelIndices(ci); if (!iter) { - ci->emitOpError("channel indices failed to convert to convert to " - "metadataArray index."); + ci->emitOpError( + "channel indices failed to convert to metadataArray index."); return failure(); } // Get metadata from metadataArray. @@ -4678,6 +4678,25 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase { alloc = memcpyOpIf.getSrcMemref(); } + // Detect if multiple outbound puts in this DMA allocation share the same + // source buffer. When true, use per-put interleaved lock placement to + // prevent the second put from overwriting the buffer before the DMA + // finishes reading the first put's data. + bool sharedStagingBuffer = false; + if (!tileInbound.value() && isa(alloc.getDefiningOp()) && + dma_alloc.value().memcpyOps.size() > 1) { + int sameBufCount = 0; + for (auto *op : dma_alloc.value().memcpyOps) { + if (auto other = dyn_cast_if_present(op)) { + auto otherInbound = isTileInbound(other, air::MemorySpace::L1); + if (succeeded(otherInbound) && !otherInbound.value() && + other.getSrcMemref() == alloc) + sameBufCount++; + } + } + sharedStagingBuffer = sameBufCount > 1; + } + if (auto bco = dyn_cast_if_present( alloc.getDefiningOp())) builder.setInsertionPoint(bco.getOperand().getDefiningOp()); @@ -4685,7 +4704,19 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase { builder.setInsertionPoint(alloc.getDefiningOp()); else if (!tileInbound.value() && isa(alloc.getDefiningOp())) { - builder.setInsertionPoint(memcpyOpIf); + if (sharedStagingBuffer) { + // Interleaved mode: acquire immediately before this specific put, so + // the core waits for the DMA to finish reading the previous put's + // data before overwriting the buffer. + builder.setInsertionPoint(memcpyOpIf); + } else { + auto br = dyn_cast_if_present( + memcpyOpIf->getBlock()->getTerminator()); + if (br) + builder.setInsertionPointToStart(br.getDest()); + else + builder.setInsertionPointToStart(memcpyOpIf->getBlock()); + } } else builder.setInsertionPoint(memcpyOpIf); @@ -4696,10 +4727,11 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase { lockAqValue); // Try to find the end of lifetime for the data copied by memcpyOpIf, and - // put the unlock. For outbound puts from AIE::BufferOp, release - // immediately after the put to enable interleaved operation when multiple - // puts share the same staging buffer. - if (!tileInbound.value() && isa(alloc.getDefiningOp())) { + // put the unlock. + if (sharedStagingBuffer) { + // Interleaved mode: release rlock immediately after the put so the DMA + // can read the buffer before the next put overwrites it. The next put's + // acquire(wlock) will block until the DMA completes reading. builder.setInsertionPointAfter(memcpyOpIf); AIE::UseLockOp::create(builder, memcpyOpIf->getLoc(), relLockOp, AIE::LockAction::Release, lockRelValue); From fdc78b29fee597914f80aad8ac34088629884f88 Mon Sep 17 00:00:00 2001 From: erweiw Date: Wed, 8 Apr 2026 14:41:28 -0700 Subject: [PATCH 6/6] Add KV cache prefill to operator dashboard Register flash_attention/kv_cache_prefill in the programming examples dashboard generator. Shows as NPU2-only (green) based on the lit test. Co-Authored-By: Claude Opus 4.6 (1M context) --- programming_examples/generate_readme.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/programming_examples/generate_readme.py b/programming_examples/generate_readme.py index 9c223f45d..592a51aef 100644 --- a/programming_examples/generate_readme.py +++ b/programming_examples/generate_readme.py @@ -216,6 +216,12 @@ "path": "flash_attention/kernel_fusion_based", "datatypes": "bf16", }, + { + "category": "Attention", + "name": "Flash Attention + KV Cache Prefill", + "path": "flash_attention/kv_cache_prefill", + "datatypes": "bf16", + }, { "category": "Data Movement", "name": "Passthrough (DMA)",