From 580a0e77e04213d1be771e52d8f7cb35583d529f Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Fri, 10 Apr 2026 23:40:19 -0700 Subject: [PATCH] Add packet-switched flash attention programming example MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Recover the original packet-switched flash attention design as a standalone example. This design uses dma_packet channels to time-multiplex Q and K data through shared compute tile S2MM DMA channels via hardware packet routing in the stream switch. Channel routing: L2ToL1Chan1 (Q): dma_packet — broadcast to all compute tiles L2ToL1Chan2 (K): dma_packet — broadcast to all compute tiles L2ToL1Chan3 (V): dma_stream — circuit-switched per cascade stage Includes both NPU2 (AIE2P, attn.py + attn_pkt.cc) and NPU1 (AIE2, attn_npu1.py + attn_npu1.cc) variants. The NPU1 variant reuses the kernel from kernel_fusion_based with k-major B-block indexing and adapted DMA layouts for mmul<4,8,4>. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../flash_attention/packet_switched/Makefile | 101 ++ .../flash_attention/packet_switched/attn.py | 1387 +++++++++++++++++ .../packet_switched/attn_npu1.cc | 1012 ++++++++++++ .../packet_switched/attn_npu1.py | 1341 ++++++++++++++++ .../packet_switched/attn_pkt.cc | 680 ++++++++ .../run_npu1_makefile_peano.lit | 10 + .../run_npu2_makefile_peano.lit | 10 + .../flash_attention/packet_switched/zero.cc | 47 + 8 files changed, 4588 insertions(+) create mode 100644 programming_examples/flash_attention/packet_switched/Makefile create mode 100644 programming_examples/flash_attention/packet_switched/attn.py create mode 100644 programming_examples/flash_attention/packet_switched/attn_npu1.cc create mode 100644 programming_examples/flash_attention/packet_switched/attn_npu1.py create mode 100644 programming_examples/flash_attention/packet_switched/attn_pkt.cc create mode 100644 programming_examples/flash_attention/packet_switched/run_npu1_makefile_peano.lit create 
mode 100644 programming_examples/flash_attention/packet_switched/run_npu2_makefile_peano.lit
 create mode 100644 programming_examples/flash_attention/packet_switched/zero.cc

diff --git a/programming_examples/flash_attention/packet_switched/Makefile b/programming_examples/flash_attention/packet_switched/Makefile
new file mode 100644
index 000000000..4f19aa201
--- /dev/null
+++ b/programming_examples/flash_attention/packet_switched/Makefile
@@ -0,0 +1,101 @@
+# Copyright (C) 2025, Advanced Micro Devices, Inc.
+# SPDX-License-Identifier: MIT
+#
+# Packet-switched flash attention example.
+# Uses dma_packet channels for Q and K routing through shared
+# compute tile S2MM DMA channels.
+# Supports both NPU2 (AIE2P) and NPU1 (AIE2).
+# Kernels build with Peano clang++ when PEANO_INSTALL_DIR is set, else with xchesscc_wrapper.
+srcdir := $(shell dirname $(realpath $(firstword $(MAKEFILE_LIST))))
+
+# Attention parameters
+LK ?= 512
+LKP ?= 64
+LQ ?= 512
+LQP ?= 256
+DK ?= 64
+DV ?= 64
+NUM_HEADS ?= 2
+NUM_KV_HEADS ?= $(NUM_HEADS)
+VAL_RANGE ?= 3
+
+# Derived: per-tile Q rows handed to the kernel as -Dlqp = LQP / NUM_Q_TILES
+NUM_Q_TILES ?= 4
+LQP_TILE := $(shell echo $$(($(LQP) / $(NUM_Q_TILES))))
+
+# Determine build dir based on whether PEANO_INSTALL_DIR is set
+ifdef PEANO_INSTALL_DIR
+  BUILD_DIR := build_peano
+else
+  BUILD_DIR := build_chess
+endif
+
+AIEOPT_DIR = $(shell realpath $(dir $(shell which aie-opt))/..)
+WARNING_FLAGS = -Wno-parentheses -Wno-attributes -Wno-macro-redefined -Wno-empty-body
+PEANOWRAP2P_FLAGS = -O2 -std=c++20 --target=aie2p-none-unknown-elf ${WARNING_FLAGS} -DNDEBUG -I ${AIEOPT_DIR}/include
+PEANOWRAP2_FLAGS = -O2 -std=c++20 --target=aie2-none-unknown-elf ${WARNING_FLAGS} -DNDEBUG -I ${AIEOPT_DIR}/include
+
+# ============================================================================
+# NPU2 (AIE2P) targets — default
+# ============================================================================
+.PHONY: all print run compile-kernel print-npu1 run-npu1 compile-kernel-npu1 clean
+all: run
+
+print:
+	${powershell} python3 ${srcdir}/attn.py -p --lk $(LK) --lkp $(LKP) --lq $(LQ) --lqp $(LQP) --dk $(DK) --dv $(DV) --num-heads $(NUM_HEADS) --num-kv-heads $(NUM_KV_HEADS)
+# The PEANO_INSTALL_DIR assignment must prefix python3 (not cd), or it never reaches python's environment.
+run: compile-kernel
+	mkdir -p $(BUILD_DIR)
+	cd $(BUILD_DIR) && $(if $(PEANO_INSTALL_DIR),PEANO_INSTALL_DIR=$(PEANO_INSTALL_DIR)) ${powershell} python3 ${srcdir}/attn.py --lk $(LK) --lkp $(LKP) --lq $(LQ) --lqp $(LQP) --dk $(DK) --dv $(DV) --num-heads $(NUM_HEADS) --num-kv-heads $(NUM_KV_HEADS) --val-range $(VAL_RANGE) $(EXTRA_PY_FLAGS)
+
+compile-kernel:
+	mkdir -p $(BUILD_DIR)
+	@if [ -n "$(PEANO_INSTALL_DIR)" ]; then \
+		echo "Detected PEANO_INSTALL_DIR from environment: $(PEANO_INSTALL_DIR)"; \
+		if [ -x "$(PEANO_INSTALL_DIR)/bin/clang++" ]; then \
+			echo "Using clang++ from PEANO_INSTALL_DIR=$(PEANO_INSTALL_DIR)"; \
+			$(PEANO_INSTALL_DIR)/bin/clang++ ${PEANOWRAP2P_FLAGS} -DBIT_WIDTH=8 -c ${srcdir}/attn_pkt.cc -o $(BUILD_DIR)/attn_pkt.o -Dlqp=$(LQP_TILE) -Dlkp=$(LKP) -Ddk=$(DK) -Ddv=$(DV) -DAIE_API_EMULATE_BFLOAT16_MMUL_WITH_BFP16 -DROUND_CONV_EVEN $(EXTRA_KERNEL_FLAGS); \
+		else \
+			echo "Error: invalid PEANO_INSTALL_DIR, clang++ not found."; \
+			exit 1; \
+		fi; \
+	elif command -v xchesscc_wrapper >/dev/null 2>&1; then \
+		echo "Using xchesscc_wrapper from PATH"; \
+		cd $(BUILD_DIR) && ${powershell} xchesscc_wrapper aie2p -c ${srcdir}/attn_pkt.cc -o attn_pkt.o -Dlqp=$(LQP_TILE) -Dlkp=$(LKP) -Ddk=$(DK) -Ddv=$(DV); \
+	else \
+		echo "Error: Neither PEANO_INSTALL_DIR nor xchesscc_wrapper found."; \
+		exit 1; \
+	fi
+
+# ============================================================================
+# NPU1 (AIE2) targets
+# ============================================================================
+
+print-npu1:
+	${powershell} python3 ${srcdir}/attn_npu1.py -p --lk $(LK) --lkp $(LKP) --lq $(LQ) --lqp $(LQP) --dk $(DK) --dv $(DV) --num-heads $(NUM_HEADS) --num-kv-heads $(NUM_KV_HEADS)
+# As in `run`: PEANO_INSTALL_DIR is attached to the python3 command and only when it is non-empty.
+run-npu1: compile-kernel-npu1
+	mkdir -p $(BUILD_DIR)
+	cd $(BUILD_DIR) && $(if $(PEANO_INSTALL_DIR),PEANO_INSTALL_DIR=$(PEANO_INSTALL_DIR)) ${powershell} python3 ${srcdir}/attn_npu1.py --lk $(LK) --lkp $(LKP) --lq $(LQ) --lqp $(LQP) --dk $(DK) --dv $(DV) --num-heads $(NUM_HEADS) --num-kv-heads $(NUM_KV_HEADS) --val-range $(VAL_RANGE) $(EXTRA_PY_FLAGS)
+
+compile-kernel-npu1:
+	mkdir -p $(BUILD_DIR)
+	@if [ -n "$(PEANO_INSTALL_DIR)" ]; then \
+		echo "Detected PEANO_INSTALL_DIR from environment: $(PEANO_INSTALL_DIR)"; \
+		if [ -x "$(PEANO_INSTALL_DIR)/bin/clang++" ]; then \
+			echo "Using clang++ from PEANO_INSTALL_DIR=$(PEANO_INSTALL_DIR) for AIE2 (NPU1)"; \
+			$(PEANO_INSTALL_DIR)/bin/clang++ ${PEANOWRAP2_FLAGS} -DBIT_WIDTH=8 -c ${srcdir}/attn_npu1.cc -o $(BUILD_DIR)/attn_npu1.o -I${srcdir}/../dataflow_based -Dlqp=$(LQP_TILE) -Dlkp=$(LKP) -Ddk=$(LKP) -Ddk_full=$(DK) -Ddv=$(LKP) -Ddv_full=$(DV) $(EXTRA_KERNEL_FLAGS); \
+		else \
+			echo "Error: invalid PEANO_INSTALL_DIR, clang++ not found."; \
+			exit 1; \
+		fi; \
+	elif command -v xchesscc_wrapper >/dev/null 2>&1; then \
+		echo "Using xchesscc_wrapper from PATH for AIE2 (NPU1)"; \
+		cd $(BUILD_DIR) && ${powershell} xchesscc_wrapper aie2 -c ${srcdir}/attn_npu1.cc -o attn_npu1.o -I${srcdir}/../dataflow_based -Dlqp=$(LQP_TILE) -Dlkp=$(LKP) -Ddk=$(LKP) -Ddk_full=$(DK) -Ddv=$(LKP) -Ddv_full=$(DV); \
+	else \
+		echo "Error: Neither PEANO_INSTALL_DIR nor xchesscc_wrapper found."; \
+		exit 1; \
+	fi
+
+clean:
+	rm -rf $(BUILD_DIR) __pycache__
diff --git a/programming_examples/flash_attention/packet_switched/attn.py
b/programming_examples/flash_attention/packet_switched/attn.py new file mode 100644 index 000000000..3124f8f9f --- /dev/null +++ b/programming_examples/flash_attention/packet_switched/attn.py @@ -0,0 +1,1387 @@ +# Copyright (C) 2025, Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT +"""Flash attention with packet-switched Q/K routing (NPU2 / AIE2P). + +This is the original flash attention design that uses packet-switched DMA +channels (channel_type="dma_packet") to time-multiplex Q and K data through +shared compute tile S2MM DMA channels. The stream switch demultiplexes +incoming packets to the correct tile based on packet IDs. + +Channel routing: + L2ToL1Chan1 (Q): dma_packet — broadcast to [num_q_tiles, num_cascade_stages] + L2ToL1Chan2 (K): dma_packet — broadcast to [num_q_tiles, num_cascade_stages] + L2ToL1Chan3 (V): dma_stream — circuit-switched per cascade stage + +This design was later replaced by a memtile-relayed selective-capture design +(see kernel_fusion_based/) which uses circuit-switched routing with software- +based Q selection. This example preserves the packet-switched variant as a +reference implementation. 
+""" + +import argparse +from math import cos, sin, sqrt, exp +import numpy as np + +import air +from air.ir import * +from air.dialects.affine import apply as affine_apply +from air.dialects.air import * +from air.dialects.arith import ConstantOp +from air.dialects.memref import AllocOp, CollapseShapeOp, DeallocOp, load, store +from air.dialects.func import FuncOp, CallOp +from air.dialects.scf import for_, yield_ +from air.dialects import scf, affine, arith + +range_ = for_ + + +@module_builder +def build_module( + lk=12288, + lkp=96, + lq=512, + lqp=128, + dk=64, + dv=64, + num_q_tiles=4, + num_cascade_stages=4, + num_heads=12, + num_kv_heads=None, + causal=False, +): + """Build the attention module using Python bindings + + Args: + lk: Total sequence length for K/V matrices (default: 12288) + lkp: Chunk size for K/V processing per AIE tile (default: 96) + lq: Total sequence length for Q matrix (default: 512) + lqp: Chunk size for Q processing per launch iteration (default: 128) + dk: Key dimension (default: 64) + dv: Value dimension (default: 64) + num_q_tiles: Number of tiles to partition Q chunk (lqp) into (default: 4) + num_cascade_stages: Number of cascade pipeline stages (default: 4) + num_heads: Number of Q attention heads (default: 12) + num_kv_heads: Number of K/V heads (default: num_heads for MHA, < num_heads for GQA) + causal: Enable causal masking (default: False) + """ + if num_kv_heads is None: + num_kv_heads = num_heads # MHA: every Q head has its own KV head + + # Validate divisibility requirements + assert lq % lqp == 0, f"lq ({lq}) must be divisible by lqp ({lqp})" + assert ( + lqp % num_q_tiles == 0 + ), f"lqp ({lqp}) must be divisible by num_q_tiles ({num_q_tiles})" + assert lk % lkp == 0, f"lk ({lk}) must be divisible by lkp ({lkp})" + assert ( + lk % (lkp * num_cascade_stages) == 0 + ), f"lk ({lk}) must be divisible by lkp * num_cascade_stages ({lkp * num_cascade_stages})" + tile_size_q_check = lqp // num_q_tiles + enable_shared_buffers = 
lkp == dk and tile_size_q_check <= lkp + if causal: + assert lq == lk, f"Causal masking requires lq == lk, got lq={lq}, lk={lk}" + assert lkp == dk, ( + f"Causal masking requires lkp == dk (enable_shared_buffers) for " + f"the prefix+suffix BD collapse to produce infinite-loop DMAs " + f"(no PDI reset between iterations). Got lkp={lkp}, dk={dk}." + ) + tile_size_q = lqp // num_q_tiles + assert ( + tile_size_q == lkp + ), f"Causal masking requires tile_size_q == lkp, got {tile_size_q} vs {lkp}" + assert ( + num_heads % 2 == 0 + ), f"num_heads ({num_heads}) must be divisible by 2 (segment unroll constraint)" + assert num_kv_heads > 0, "num_kv_heads must be positive" + assert ( + num_heads % num_kv_heads == 0 + ), f"num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})" + gqa_group_size = num_heads // num_kv_heads + + bf16 = Type.parse("bf16") + i32 = IntegerType.get_signless(32) + index_type = IndexType.get() + + # Architecture-specific matrix multiplication dimensions + mmul_mkn = [8, 8, 8] # For aie2p + mmul_m, mmul_k, mmul_n = mmul_mkn + + # Hardware constraint: max 2 heads per segment unroll + num_heads_per_unroll = 2 + num_head_groups = num_heads // num_heads_per_unroll + + # Derived parameters + num_chunks = lk // lkp + chunks_per_stage = num_chunks // num_cascade_stages + num_lq_iters = lq // lqp # Total Q iterations + # Q iteration at launch level for both causal and non-causal. + # Keeping Q at launch level avoids DMA task ordering conflicts: when Q + # iterates on-device, Q and K share the same compute-tile S2MM channel, + # and getRepeatCounts groups them into sequential tasks [Q×N, K×M] + # instead of interleaved [Q, K×M, Q, K×M, ...], causing deadlock. + # For causal masking, the launch Q index is threaded through to the herd + # body for the block index computation. 
+ launch_lq_iters = num_lq_iters + device_lq_iters = 1 + tile_size_q = lqp // num_q_tiles # Tile size within each lqp chunk + + # Memory spaces: L1 = 2 : i32, L2 = 1 : i32 + l1_space = IntegerAttr.get(i32, 2) # L1 uses memory space 2 + l2_space = IntegerAttr.get(i32, 1) # L2 uses memory space 1 + + # L1 MemRefTypes (memory space 2 : i32) - used in herd bodies + memref_lqp_dv_l1 = MemRefType.get([tile_size_q, dk], bf16, memory_space=l1_space) + memref_lqp_l1 = MemRefType.get([tile_size_q, 1], bf16, memory_space=l1_space) + memref_lqp_lkp_l1 = MemRefType.get([tile_size_q * lkp], bf16, memory_space=l1_space) + memref_dv_lkp_l1 = MemRefType.get([lkp, dk], bf16, memory_space=l1_space) + memref_g_shared_l1 = MemRefType.get([tile_size_q, lkp], bf16, memory_space=l1_space) + + # L2 MemRefTypes (memory space 1 : i32) - segment allocations + memref_lqp_dk_l2 = MemRefType.get([tile_size_q, dk], bf16, memory_space=l2_space) + memref_dk_lkp_l2 = MemRefType.get([lkp, dk], bf16, memory_space=l2_space) + memref_lkp_dv_l2 = MemRefType.get([lkp, dk], bf16, memory_space=l2_space) + memref_output_lqp_dv_l2 = MemRefType.get( + [lqp, dk], bf16, memory_space=l2_space + ) # Per-iteration output buffer + + # L3 MemRefTypes (no memory space annotation = default L3) - with head dimension + memref_input_q_lq_dk = MemRefType.get([num_heads, lq, dk], bf16) + memref_output_lq_dv = MemRefType.get([num_heads, lq, dk], bf16) + memref_input_k_dk_lk = MemRefType.get([num_kv_heads, lk, dk], bf16) + memref_input_v_lk_dv = MemRefType.get([num_kv_heads, lk, dk], bf16) + memref_input_m_lq_lk = MemRefType.get([num_heads, lq, lk], bf16) + + # Helper function to create external function declarations + def external_func(name, inputs, outputs=None, link_with=None, visibility="private"): + if outputs is None: + outputs = [] + func_type = FunctionType.get(inputs, outputs) + func = FuncOp(name=name, type=func_type, visibility=visibility) + func.attributes["llvm.emit_c_interface"] = UnitAttr.get() + if link_with: 
+ func.attributes["link_with"] = StringAttr.get(link_with) + return func + + # External function declarations + external_func("zero_fill_gp_bf16", [memref_lqp_dv_l1], link_with="attn_pkt.o") + external_func("zero_fill_sp_bf16", [memref_lqp_l1], link_with="attn_pkt.o") + external_func("zero_fill_g_bf16", [memref_lqp_lkp_l1], link_with="attn_pkt.o") + external_func("neg_inf_fill_up_bf16", [memref_lqp_l1], link_with="attn_pkt.o") + external_func( + "matmul_a_b_bf16", + [memref_lqp_dv_l1, memref_dv_lkp_l1, memref_lqp_lkp_l1], + link_with="attn_pkt.o", + ) + external_func( + "matmul_g_b_bf16", + [memref_lqp_lkp_l1, memref_dv_lkp_l1, memref_lqp_dv_l1], + link_with="attn_pkt.o", + ) + external_func( + "max_g_bf16", [memref_lqp_lkp_l1, memref_lqp_l1], link_with="attn_pkt.o" + ) + external_func( + "fused_softmax", + [memref_lqp_lkp_l1, memref_lqp_l1, memref_lqp_l1, memref_lqp_l1], + link_with="attn_pkt.o", + ) + external_func( + "maximum_up_u_bf16", [memref_lqp_l1, memref_lqp_l1], link_with="attn_pkt.o" + ) + external_func( + "exp_g_minus_u", [memref_lqp_l1, memref_lqp_lkp_l1], link_with="attn_pkt.o" + ) + external_func( + "exp_up_minus_u", + [memref_lqp_l1, memref_lqp_l1, memref_lqp_l1], + link_with="attn_pkt.o", + ) + external_func("mul_r_gp", [memref_lqp_l1, memref_lqp_dv_l1], link_with="attn_pkt.o") + external_func("sum_g", [memref_lqp_lkp_l1, memref_lqp_l1], link_with="attn_pkt.o") + external_func( + "accum_sp_r_s", + [memref_lqp_l1, memref_lqp_l1, memref_lqp_l1], + link_with="attn_pkt.o", + ) + external_func( + "vector_copy_32elems", + [i32, memref_lqp_l1, memref_lqp_l1], + link_with="attn_pkt.o", + ) + external_func( + "copy_tile", [memref_dv_lkp_l1, memref_lqp_dv_l1], link_with="attn_pkt.o" + ) + external_func( + "div_gp_sp", [memref_lqp_l1, memref_lqp_dv_l1], link_with="attn_pkt.o" + ) + external_func( + "vector_copy_swizzle_elems", + [i32, memref_lqp_lkp_l1, memref_lqp_lkp_l1], + link_with="attn_pkt.o", + ) + external_func( + "vector_copy_unswizzle_elems", + [i32, 
memref_lqp_lkp_l1, memref_lqp_lkp_l1], + link_with="attn_pkt.o", + ) + external_func( + "add_gp_g", [memref_lqp_dv_l1, memref_lqp_dv_l1], link_with="attn_pkt.o" + ) + # Local i32 buffer for passing block indices to apply_causal_mask + # (unconditional i32 stores, kernel handles conditionals) + memref_2xi32_l1 = MemRefType.get([2], i32, memory_space=l1_space) + if causal: + external_func( + "apply_causal_mask", + [memref_lqp_lkp_l1, i32, i32], + link_with="attn_pkt.o", + ) + + # Channel declarations - use num_heads_per_unroll (2) for segment unroll + Channel("L3ToL2Chan1", size=[num_heads_per_unroll, num_cascade_stages]) + Channel("L3ToL2Chan2", size=[num_heads_per_unroll, num_cascade_stages]) + chan_l2_to_l1_2 = Channel( + "L2ToL1Chan2", + size=[1, num_cascade_stages], + broadcast_shape=[num_q_tiles, num_cascade_stages], + ) + chan_l2_to_l1_2.attributes["channel_type"] = StringAttr.get("dma_packet") + if not enable_shared_buffers: + chan_l2_to_l1_1 = Channel( + "L2ToL1Chan1", + size=[num_q_tiles, 1], + broadcast_shape=[num_q_tiles, num_cascade_stages], + ) + chan_l2_to_l1_1.attributes["channel_type"] = StringAttr.get("dma_packet") + chan_l2_to_l1_3 = Channel( + "L2ToL1Chan3", + size=[1, num_cascade_stages], + broadcast_shape=[num_q_tiles, num_cascade_stages], + ) + Channel("L1ToL2Chan1", size=[num_q_tiles, 1]) + Channel("L2ToL3Chan1", size=[num_heads_per_unroll]) + chan_cascade = Channel("cascade", size=[num_q_tiles, num_cascade_stages - 1]) + chan_cascade.attributes["channel_type"] = StringAttr.get("cascade") + + # Main attention function + @FuncOp.from_py_func( + memref_input_q_lq_dk, + memref_input_k_dk_lk, + memref_input_v_lk_dv, + memref_input_m_lq_lk, + memref_output_lq_dv, + ) + def attention_bf16(arg0, arg1, arg2, arg3, arg4): + c_launch_lq = ConstantOp(index_type, launch_lq_iters) + c_num_head_groups = ConstantOp(index_type, num_head_groups) + + # Non-causal: launch iterates Q blocks at host level (no BD chain limit) + # Causal: launch size 1, Q iteration 
inside herd (device-local q_block) + @launch( + operands=[arg0, arg1, arg2, arg4], sizes=[c_launch_lq, c_num_head_groups] + ) + def launch_body(arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12): + # arg5 = Q iteration index (0..launch_lq_iters-1), arg6 = head group + c0 = ConstantOp(index_type, 0) + c1 = ConstantOp(index_type, 1) + + # Compute actual head indices from head group + # head_base = arg6 * 2 (for head groups 0,1,2,3,4,5 -> heads 0-1, 2-3, 4-5, 6-7, 8-9, 10-11) + affine_map_head_base = AffineMap.get( + 0, + 1, + [ + AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(num_heads_per_unroll), + ) + ], + ) + head_base = affine_apply(affine_map_head_base, [arg6]) + affine_map_add_one = AffineMap.get( + 0, + 1, + [ + AffineExpr.get_add( + AffineSymbolExpr.get(0), AffineConstantExpr.get(1) + ) + ], + ) + head_1 = affine_apply(affine_map_add_one, [head_base]) + + # GQA: compute KV head indices from Q head indices + # kv_head = q_head // gqa_group_size + if gqa_group_size == 1: + # MHA: kv_head == q_head + kv_head_base = head_base + kv_head_1 = head_1 + else: + affine_map_kv_head = AffineMap.get( + 0, + 1, + [ + AffineExpr.get_floor_div( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(gqa_group_size), + ) + ], + ) + kv_head_base = affine_apply(affine_map_kv_head, [head_base]) + kv_head_1 = affine_apply(affine_map_kv_head, [head_1]) + + # Affine map for Q tile partitioning within lqp chunk + affine_map_tileq = AffineMap.get( + 0, + 1, + [ + AffineExpr.get_mul( + AffineSymbolExpr.get(0), AffineConstantExpr.get(tile_size_q) + ) + ], + ) + # Affine map for launch offset: arg5 * lqp * dk + affine_map_launch_offset = AffineMap.get( + 0, + 1, + [ + AffineExpr.get_mul( + AffineSymbolExpr.get(0), AffineConstantExpr.get(lqp * dk) + ) + ], + ) + # Affine map for Q head offset: head * lq * dk + launch_offset + affine_map_q_head_offset = AffineMap.get( + 0, + 2, + [ + AffineExpr.get_add( + AffineExpr.get_mul( + AffineSymbolExpr.get(0), 
AffineConstantExpr.get(lq * dk) + ), + AffineSymbolExpr.get(1), + ) + ], + ) + # Affine map for K head offset: head * lk * dk + row_offset * dk + # K stored as [num_kv_heads, lk, dk] (row-major) + affine_map_head_row = AffineMap.get( + 0, + 2, + [ + AffineExpr.get_add( + AffineExpr.get_mul( + AffineSymbolExpr.get(0), AffineConstantExpr.get(lk * dk) + ), + AffineExpr.get_mul( + AffineSymbolExpr.get(1), AffineConstantExpr.get(dk) + ), + ) + ], + ) + # Affine map for V head offset: head * lk * dv + affine_map_v_head_offset = AffineMap.get( + 0, + 1, + [ + AffineExpr.get_mul( + AffineSymbolExpr.get(0), AffineConstantExpr.get(lk * dv) + ) + ], + ) + + # Combined Q/K/V/output DMA loop — one iteration per q_iter + # Must be a single loop so Q, K, V, and output are interleaved in + # the correct order matching the segment's consumption pattern. + c_device_lq_iters = ConstantOp(index_type, device_lq_iters) + for lq_it in range_(c0, c_device_lq_iters, c1): + # Combine launch Q index (arg5) + device Q index (lq_it) + # Non-causal: arg5 varies, lq_it=0. Causal: arg5=0, lq_it varies. 
+ q_iter_global = arith.AddIOp(arg5, lq_it) + + # (A) Q: L3→L2 for this q_iter + par_1 = scf.ForallOp( + lower_bounds=[0], upper_bounds=[num_cascade_stages], steps=[1] + ) + with InsertionPoint(par_1.body): + tile_offset = affine_apply( + affine_map_tileq, [par_1.induction_variables[0]] + ) + launch_offset = affine_apply( + affine_map_launch_offset, [q_iter_global.result] + ) + # Head 0 in group (head_base) + q_head0_off = affine_apply( + affine_map_q_head_offset, [head_base, launch_offset] + ) + ChannelPut( + "L3ToL2Chan1", + arg9, + indices=[c0, par_1.induction_variables[0]], + offsets=[tile_offset, q_head0_off], + sizes=[tile_size_q, dk], + strides=[dk, 1], + ) + # Head 1 in group (head_base + 1) + q_head1_off = affine_apply( + affine_map_q_head_offset, [head_1, launch_offset] + ) + ChannelPut( + "L3ToL2Chan1", + arg9, + indices=[c1, par_1.induction_variables[0]], + offsets=[tile_offset, q_head1_off], + sizes=[tile_size_q, dk], + strides=[dk, 1], + ) + scf.InParallelOp() + + # (B) K: L3→L2 for this q_iter (same K data re-sent each iter) + for i in range(num_cascade_stages): + row_off = ConstantOp(index_type, i * chunks_per_stage * lkp) + k_head0_off = affine_apply( + affine_map_head_row, [kv_head_base, row_off] + ) + ChannelPut( + "L3ToL2Chan1", + arg10, + indices=[c0, i], + offsets=[0, 0, k_head0_off], + sizes=[chunks_per_stage, lkp, dk], + strides=[lkp * dk, dk, 1], + ) + k_head1_off = affine_apply( + affine_map_head_row, [kv_head_1, row_off] + ) + ChannelPut( + "L3ToL2Chan1", + arg10, + indices=[c1, i], + offsets=[0, 0, k_head1_off], + sizes=[chunks_per_stage, lkp, dk], + strides=[lkp * dk, dk, 1], + ) + + # (C) V: L3→L2 for this q_iter (same V data re-sent each iter) + for i in range(num_cascade_stages): + v_head0_off = affine_apply(affine_map_v_head_offset, [kv_head_base]) + ChannelPut( + "L3ToL2Chan2", + arg11, + indices=[c0, i], + offsets=[0, i * chunks_per_stage * lkp, v_head0_off], + sizes=[chunks_per_stage, lkp, dv], + strides=[lkp * dv, dv, 1], + ) + 
v_head1_off = affine_apply(affine_map_v_head_offset, [kv_head_1]) + ChannelPut( + "L3ToL2Chan2", + arg11, + indices=[c1, i], + offsets=[0, i * chunks_per_stage * lkp, v_head1_off], + sizes=[chunks_per_stage, lkp, dv], + strides=[lkp * dv, dv, 1], + ) + + # (D) Output: L2→L3 for this q_iter + launch_offset_out = affine_apply( + affine_map_launch_offset, [q_iter_global.result] + ) + out_head0_off = affine_apply( + affine_map_q_head_offset, [head_base, launch_offset_out] + ) + out_head1_off = affine_apply( + affine_map_q_head_offset, [head_1, launch_offset_out] + ) + ChannelGet( + "L2ToL3Chan1", + arg12, + indices=[c0], + offsets=[0, out_head0_off], + sizes=[lqp, dk], + strides=[dk, 1], + ) + ChannelGet( + "L2ToL3Chan1", + arg12, + indices=[c1], + offsets=[0, out_head1_off], + sizes=[lqp, dk], + strides=[dk, 1], + ) + + yield_([]) + + # Segment unrolls over 2 heads (hardware constraint) + c_num_heads_unroll = ConstantOp(index_type, num_heads_per_unroll) + c_dummy_size = ConstantOp(index_type, 1) + + # In causal mode, pass launch Q index through segment to herd + # for causal block index computation. After runtime loop tiling + # (runtime_loop_tiling_sizes=[1,1]), arg5 becomes a constant in + # each tiled iteration, so the RTP write in airrt-to-npu succeeds. 
+ seg_operands = [] + + @segment( + name="attention_seg", + operands=seg_operands, + sizes=[c_num_heads_unroll, c_dummy_size], + ) + def segment_body(*seg_args): + head_idx, dummy_idx, head_size, dummy_size = seg_args[:4] + launch_q_idx = seg_args[4] if (causal and len(seg_args) > 4) else None + # L2 allocations + if enable_shared_buffers: + alloc = alloc_col1 = alloc_col2 = alloc_col3 = None + else: + alloc = AllocOp(memref_lqp_dk_l2, [], []) + alloc_col1 = AllocOp(memref_lqp_dk_l2, [], []) + alloc_col2 = AllocOp(memref_lqp_dk_l2, [], []) + alloc_col3 = AllocOp(memref_lqp_dk_l2, [], []) + alloc_2 = AllocOp(memref_dk_lkp_l2, [], []) + alloc_21 = AllocOp(memref_dk_lkp_l2, [], []) + alloc_22 = AllocOp(memref_dk_lkp_l2, [], []) + alloc_23 = AllocOp(memref_dk_lkp_l2, [], []) + alloc_3 = AllocOp(memref_lkp_dv_l2, [], []) + alloc_31 = AllocOp(memref_lkp_dv_l2, [], []) + alloc_32 = AllocOp(memref_lkp_dv_l2, [], []) + alloc_33 = AllocOp(memref_lkp_dv_l2, [], []) + alloc_5 = AllocOp(memref_output_lqp_dv_l2, [], []) + up = AllocOp(memref_lqp_l1, [], []) + sp = AllocOp(memref_lqp_l1, [], []) + Gp = AllocOp(memref_lqp_dv_l1, [], []) + alloc_6 = AllocOp(memref_lqp_dv_l1, [], []) + if enable_shared_buffers: + G_shared = AllocOp(memref_g_shared_l1, [], []) + QK_shared = AllocOp(memref_dv_lkp_l1, [], []) + else: + G_shared = None + QK_shared = None + # Local counter for causal block index tracking. + # Passed as memref operand (NOT scalar) → no RTP, no herd lock. 
+ causal_counter = AllocOp(memref_2xi32_l1, [], []) if causal else None + + c_num_q_tiles = ConstantOp(index_type, num_q_tiles) + c_num_cascade = ConstantOp(index_type, num_cascade_stages) + c0_seg = ConstantOp(index_type, 0) + c1_seg = ConstantOp(index_type, 1) + c2_seg = ConstantOp(index_type, 2) + c3_seg = ConstantOp(index_type, 3) + + # Q/K/V/output DMA loop over lq_iters (Q iteration moved from launch to device) + q_l2_bufs = ( + [alloc_2, alloc_21, alloc_22, alloc_23] + if enable_shared_buffers + else [alloc, alloc_col1, alloc_col2, alloc_col3] + ) + q_chan = "L2ToL1Chan2" if enable_shared_buffers else "L2ToL1Chan1" + q_idx = lambda col: ( + [c0_seg, col] if enable_shared_buffers else [col, c0_seg] + ) + + c_device_lq_seg = ConstantOp(index_type, device_lq_iters) + for lq_it_seg in range_(c0_seg, c_device_lq_seg, c1_seg): + # (A) Q: L3→L2 gets for this q_iter's 4 tiles + ChannelGet( + "L3ToL2Chan1", q_l2_bufs[0].result, indices=[head_idx, c0_seg] + ) + ChannelGet( + "L3ToL2Chan1", q_l2_bufs[1].result, indices=[head_idx, c1_seg] + ) + ChannelGet( + "L3ToL2Chan1", q_l2_bufs[2].result, indices=[head_idx, c2_seg] + ) + ChannelGet( + "L3ToL2Chan1", q_l2_bufs[3].result, indices=[head_idx, c3_seg] + ) + + # (B) Q: L2→L1 puts for this q_iter's 4 tiles + ChannelPut( + q_chan, + q_l2_bufs[0].result, + indices=q_idx(c0_seg), + offsets=[0, 0, 0, 0], + sizes=[dk // mmul_k, tile_size_q // mmul_m, mmul_m, mmul_k], + strides=[mmul_k, dk * mmul_k, dk, 1], + ) + ChannelPut( + q_chan, + q_l2_bufs[1].result, + indices=q_idx(c1_seg), + offsets=[0, 0, 0, 0], + sizes=[dk // mmul_k, tile_size_q // mmul_m, mmul_m, mmul_k], + strides=[mmul_k, dk * mmul_k, dk, 1], + ) + ChannelPut( + q_chan, + q_l2_bufs[2].result, + indices=q_idx(c2_seg), + offsets=[0, 0, 0, 0], + sizes=[dk // mmul_k, tile_size_q // mmul_m, mmul_m, mmul_k], + strides=[mmul_k, dk * mmul_k, dk, 1], + ) + ChannelPut( + q_chan, + q_l2_bufs[3].result, + indices=q_idx(c3_seg), + offsets=[0, 0, 0, 0], + sizes=[dk // mmul_k, 
tile_size_q // mmul_m, mmul_m, mmul_k], + strides=[mmul_k, dk * mmul_k, dk, 1], + ) + + # (C) K/V streaming: L3→L2 + L2→L1 (inner loop) + for arg21 in range_(0, chunks_per_stage, 1): + # Channel gets for K and V - use head_idx + ChannelGet( + "L3ToL2Chan1", alloc_2.result, indices=[head_idx, c0_seg] + ) + ChannelGet( + "L3ToL2Chan2", alloc_3.result, indices=[head_idx, c0_seg] + ) + ChannelGet( + "L3ToL2Chan1", alloc_21.result, indices=[head_idx, c1_seg] + ) + ChannelGet( + "L3ToL2Chan2", alloc_31.result, indices=[head_idx, c1_seg] + ) + ChannelGet( + "L3ToL2Chan1", alloc_22.result, indices=[head_idx, c2_seg] + ) + ChannelGet( + "L3ToL2Chan2", alloc_32.result, indices=[head_idx, c2_seg] + ) + ChannelGet( + "L3ToL2Chan1", alloc_23.result, indices=[head_idx, c3_seg] + ) + ChannelGet( + "L3ToL2Chan2", alloc_33.result, indices=[head_idx, c3_seg] + ) + + # Channel puts for K matrix to L1 + ChannelPut( + "L2ToL1Chan2", + alloc_2.result, + indices=[c0_seg, c0_seg], + offsets=[0, 0, 0, 0], + sizes=[lkp // mmul_n, dk // mmul_k, mmul_n, mmul_k], + strides=[mmul_n * dk, mmul_k, dk, 1], + ) + ChannelPut( + "L2ToL1Chan2", + alloc_21.result, + indices=[c0_seg, c1_seg], + offsets=[0, 0, 0, 0], + sizes=[lkp // mmul_n, dk // mmul_k, mmul_n, mmul_k], + strides=[mmul_n * dk, mmul_k, dk, 1], + ) + ChannelPut( + "L2ToL1Chan2", + alloc_22.result, + indices=[c0_seg, c2_seg], + offsets=[0, 0, 0, 0], + sizes=[lkp // mmul_n, dk // mmul_k, mmul_n, mmul_k], + strides=[mmul_n * dk, mmul_k, dk, 1], + ) + ChannelPut( + "L2ToL1Chan2", + alloc_23.result, + indices=[c0_seg, c3_seg], + offsets=[0, 0, 0, 0], + sizes=[lkp // mmul_n, dk // mmul_k, mmul_n, mmul_k], + strides=[mmul_n * dk, mmul_k, dk, 1], + ) + + # Channel puts for V matrix to L1 + ChannelPut( + "L2ToL1Chan3", + alloc_3.result, + indices=[c0_seg, c0_seg], + offsets=[0, 0, 0, 0], + sizes=[dv // mmul_n, lkp // mmul_k, mmul_k, mmul_n], + strides=[mmul_n, dv * mmul_n, dv, 1], + ) + ChannelPut( + "L2ToL1Chan3", + alloc_31.result, + 
indices=[c0_seg, c1_seg], + offsets=[0, 0, 0, 0], + sizes=[dv // mmul_n, lkp // mmul_k, mmul_k, mmul_n], + strides=[mmul_n, dv * mmul_n, dv, 1], + ) + ChannelPut( + "L2ToL1Chan3", + alloc_32.result, + indices=[c0_seg, c2_seg], + offsets=[0, 0, 0, 0], + sizes=[dv // mmul_n, lkp // mmul_k, mmul_k, mmul_n], + strides=[mmul_n, dv * mmul_n, dv, 1], + ) + ChannelPut( + "L2ToL1Chan3", + alloc_33.result, + indices=[c0_seg, c3_seg], + offsets=[0, 0, 0, 0], + sizes=[dv // mmul_n, lkp // mmul_k, mmul_k, mmul_n], + strides=[mmul_n, dv * mmul_n, dv, 1], + ) + + yield_([]) + + # (D) Output: L1→L2 gather for this q_iter + affine_map_tileq_seg = AffineMap.get( + 0, + 1, + [ + AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(tile_size_q), + ) + ], + ) + par_final = scf.ForallOp( + lower_bounds=[0], upper_bounds=[c_num_q_tiles], steps=[1] + ) + with InsertionPoint(par_final.body): + apply_final = affine_apply( + affine_map_tileq_seg, [par_final.induction_variables[0]] + ) + ChannelGet( + "L1ToL2Chan1", + alloc_5.result, + indices=[par_final.induction_variables[0], 0], + offsets=[apply_final, 0], + sizes=[tile_size_q, dv], + strides=[dv, 1], + ) + scf.InParallelOp() + + # (E) Output: L2→L3 transfer for this q_iter + ChannelPut("L2ToL3Chan1", alloc_5.result, indices=[head_idx]) + + yield_([]) + + # Unified herd: init + compute loop + cascade merge + output + unified_operands = ( + [alloc_6, up, sp, Gp, G_shared, QK_shared] + if enable_shared_buffers + else [alloc_6, up, sp, Gp] + ) + # Causal: pass counter as memref operand (no RTP/lock) + if causal: + unified_operands = unified_operands + [causal_counter] + + @herd( + name="herd_0", + sizes=[c_num_q_tiles, c_num_cascade], + operands=unified_operands, + link_with="attn_pkt.o", + ) + def unified_herd_body(*args): + arg22, arg23, arg24, arg25 = args[0], args[1], args[2], args[3] + if enable_shared_buffers: + arg26, arg27, arg28, arg29, arg30, arg31 = args[4:10] + counter_buf = args[10] if causal else None + else: 
+ arg26, arg27, arg28, arg29 = args[4:8] + arg30 = arg31 = None + counter_buf = args[8] if causal else None + + if causal: + # Local counter. With lkp==dk (shared + # buffers), DMAs are infinite loops → no PDI reset + # → core loops continuously → counter persists. + # counter[0] = q_block_global + # counter[1] = boot flag (0=first, 1=initialized) + # counter[2] = head counter; advanced at the end of every q_iter and + # reset to 0 once it reaches num_head_groups (at which + # point counter[0] is bumped by num_q_tiles). + # NOTE(review): first boot writes only counter[0] and counter[1]; + # counter[2], like the boot flag itself, is assumed to start at 0 — + # confirm the buffer is zero-initialised. + c0_ctr = ConstantOp(index_type, 0) + c1_ctr = ConstantOp(index_type, 1) + boot_flag = load(counter_buf, [c1_ctr]) + c0_i32_ctr = ConstantOp(i32, 0) + is_first = arith.CmpIOp( + arith.CmpIPredicate.eq, boot_flag, c0_i32_ctr + ) + if_first = scf.IfOp(is_first) + with InsertionPoint(if_first.then_block): + # Seed q_block_global with this core's first herd index (arg22) + # and mark the counter as initialized. + q_init = arith.IndexCastOp(i32, arg22) + store(q_init, counter_buf, [c0_ctr]) + c1_i32_f = ConstantOp(i32, 1) + store(c1_i32_f, counter_buf, [c1_ctr]) + scf.YieldOp([]) + + # === OUTER Q ITERATION LOOP (device-side) === + c_lq_iters_herd = ConstantOp(index_type, device_lq_iters) + c0_q = ConstantOp(index_type, 0) + c1_q = ConstantOp(index_type, 1) + + for q_iter in range_(c0_q, c_lq_iters_herd, c1_q): + + # === INIT PHASE === + # Fetch the Q tile into arg26 and reset the running state + # (arg29 = Gp, arg28 = sp, arg27 = up). With shared buffers the + # tile arrives on L2ToL1Chan2 in arg31 and is copied into arg26. + if enable_shared_buffers: + ChannelGet("L2ToL1Chan2", arg31, indices=[arg22, arg23]) + CallOp([], "copy_tile", [arg31, arg26]) + else: + ChannelGet("L2ToL1Chan1", arg26, indices=[arg22, arg23]) + CallOp([], "zero_fill_gp_bf16", [arg29]) + CallOp([], "zero_fill_sp_bf16", [arg28]) + CallOp([], "neg_inf_fill_up_bf16", [arg27]) + + # === COMPUTE LOOP (on-device) === + c_chunks = ConstantOp(index_type, chunks_per_stage) + c0_loop = ConstantOp(index_type, 0) + c1_loop = ConstantOp(index_type, 1) + + for chunk_idx in range_(c0_loop, c_chunks, c1_loop): + # G_l1 is the score buffer for this chunk: the shared buffer + # arg30 when available, otherwise a fresh allocation. + if enable_shared_buffers: + G_l1 = CollapseShapeOp( + memref_lqp_lkp_l1, arg30, [[0, 1]] + ) + else: + G_alloc = AllocOp(memref_g_shared_l1, [], []) + G_l1 = CollapseShapeOp( + memref_lqp_lkp_l1, G_alloc.result, [[0, 1]] + ) + + CallOp([], "zero_fill_g_bf16", [G_l1]) + + if enable_shared_buffers: + ChannelGet("L2ToL1Chan2", arg31, indices=[arg22, arg23])
+ CallOp([], "matmul_a_b_bf16", [arg26, arg31, G_l1]) + else: + QK_alloc = AllocOp(memref_dv_lkp_l1, [], []) + ChannelGet( + "L2ToL1Chan2", + QK_alloc.result, + indices=[arg22, arg23], + ) + CallOp( + [], + "matmul_a_b_bf16", + [arg26, QK_alloc.result, G_l1], + ) + + alloc_57 = AllocOp(memref_dv_lkp_l1, [], []) + ChannelGet( + "L2ToL1Chan3", alloc_57.result, indices=[arg22, arg23] + ) + + if causal: + # Local counter gives q_block_global. + # No RTP/herd lock — counter loaded from + # local L1 buffer. + c_cps = ConstantOp(index_type, chunks_per_stage) + kv_block = arith.AddIOp( + arith.MulIOp(arg23, c_cps).result, chunk_idx + ) + kv_i32 = arith.IndexCastOp(i32, kv_block.result) + c0_ctr_use = ConstantOp(index_type, 0) + q_i32 = load(counter_buf, [c0_ctr_use]) + CallOp([], "apply_causal_mask", [G_l1, q_i32, kv_i32]) + + c0_i32 = ConstantOp(i32, 0) + s_l1 = AllocOp(memref_lqp_l1, [], []) + r_l1 = AllocOp(memref_lqp_l1, [], []) + + # True fused softmax: max+exp+sum with f32 intermediates + CallOp( + [], + "fused_softmax", + [G_l1, arg27, s_l1.result, r_l1.result], + ) + CallOp([], "mul_r_gp", [r_l1.result, arg29]) + CallOp( + [], + "matmul_g_b_bf16", + [G_l1, alloc_57.result, arg29], + ) + CallOp( + [], + "accum_sp_r_s", + [arg28, r_l1.result, s_l1.result], + ) + CallOp( + [], + "vector_copy_32elems", + [c0_i32, s_l1.result, arg28], + ) + + DeallocOp(s_l1) + DeallocOp(r_l1) + else: + c0_i32 = ConstantOp(i32, 0) + s_l1 = AllocOp(memref_lqp_l1, [], []) + r_l1 = AllocOp(memref_lqp_l1, [], []) + + # True fused softmax: max+exp+sum with f32 intermediates + CallOp( + [], + "fused_softmax", + [G_l1, arg27, s_l1.result, r_l1.result], + ) + CallOp([], "mul_r_gp", [r_l1.result, arg29]) + CallOp( + [], + "matmul_g_b_bf16", + [G_l1, alloc_57.result, arg29], + ) + CallOp( + [], + "accum_sp_r_s", + [arg28, r_l1.result, s_l1.result], + ) + CallOp( + [], + "vector_copy_32elems", + [c0_i32, s_l1.result, arg28], + ) + + DeallocOp(s_l1) + DeallocOp(r_l1) + + DeallocOp(alloc_57) + + if 
not enable_shared_buffers: + DeallocOp(QK_alloc) + DeallocOp(G_alloc) + yield_([]) + + # === CASCADE MERGE === + c1_h = ConstantOp(index_type, 1) + r_l1_c = AllocOp(memref_lqp_l1, [], []) + + def get_gp_cascade(): + if enable_shared_buffers: + return arg30 + else: + return AllocOp(memref_lqp_dv_l1, [], []).result + + # affine.if for last cascade stage + affine_set_last = IntegerSet.get( + 0, + 2, + [ + AffineExpr.get_add( + AffineSymbolExpr.get(1), + AffineConstantExpr.get(-num_cascade_stages + 1), + ), + AffineSymbolExpr.get(0), + AffineExpr.get_add( + AffineConstantExpr.get(num_q_tiles - 1), + AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(-1), + ), + ), + ], + [True, False, False], + ) + affine_if_last = affine.AffineIfOp( + affine_set_last, cond_operands=[arg22, arg23], has_else=True + ) + with InsertionPoint(affine_if_last.then_block): + subi = arith.SubIOp(arg23, c1_h) + ChannelPut("cascade", arg29, indices=[arg22, subi]) + ChannelPut("cascade", arg27, indices=[arg22, subi]) + ChannelPut("cascade", arg28, indices=[arg22, subi]) + affine.AffineYieldOp([]) + + with InsertionPoint(affine_if_last.else_block): + affine_set_middle = IntegerSet.get( + 0, + 2, + [ + AffineExpr.get_add( + AffineSymbolExpr.get(1), + AffineConstantExpr.get(-1), + ), + AffineExpr.get_add( + AffineConstantExpr.get(num_cascade_stages - 2), + AffineExpr.get_mul( + AffineSymbolExpr.get(1), + AffineConstantExpr.get(-1), + ), + ), + AffineSymbolExpr.get(0), + AffineExpr.get_add( + AffineConstantExpr.get(num_q_tiles - 1), + AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(-1), + ), + ), + ], + [False, False, False, False], + ) + affine_if_middle = affine.AffineIfOp( + affine_set_middle, + cond_operands=[arg22, arg23], + has_else=True, + ) + with InsertionPoint(affine_if_middle.then_block): + Gp_cascade = get_gp_cascade() + up_cascade = AllocOp(memref_lqp_l1, [], []) + sp_cascade = AllocOp(memref_lqp_l1, [], []) + ChannelGet( + "cascade", 
Gp_cascade, indices=[arg22, arg23] + ) + ChannelGet( + "cascade", up_cascade.result, indices=[arg22, arg23] + ) + ChannelGet( + "cascade", sp_cascade.result, indices=[arg22, arg23] + ) + up_B_saved = AllocOp(memref_lqp_l1, [], []) + c0_i32_m = ConstantOp(i32, 0) + CallOp( + [], + "vector_copy_32elems", + [c0_i32_m, arg27, up_B_saved.result], + ) + CallOp( + [], "maximum_up_u_bf16", [up_cascade.result, arg27] + ) + CallOp( + [], + "exp_up_minus_u", + [up_cascade.result, arg27, r_l1_c.result], + ) + r_B = AllocOp(memref_lqp_l1, [], []) + CallOp( + [], + "exp_up_minus_u", + [up_B_saved.result, arg27, r_B.result], + ) + CallOp([], "mul_r_gp", [r_l1_c.result, Gp_cascade]) + CallOp([], "mul_r_gp", [r_B.result, arg29]) + CallOp([], "add_gp_g", [arg29, Gp_cascade]) + sp_temp = AllocOp(memref_lqp_l1, [], []) + CallOp([], "zero_fill_sp_bf16", [sp_temp.result]) + CallOp( + [], + "accum_sp_r_s", + [sp_cascade.result, r_l1_c.result, sp_temp.result], + ) + CallOp( + [], + "accum_sp_r_s", + [arg28, r_B.result, sp_temp.result], + ) + CallOp( + [], + "vector_copy_32elems", + [c0_i32_m, sp_temp.result, sp_cascade.result], + ) + subi2 = arith.SubIOp(arg23, c1_h) + ChannelPut( + "cascade", Gp_cascade, indices=[arg22, subi2] + ) + ChannelPut("cascade", arg27, indices=[arg22, subi2]) + ChannelPut( + "cascade", sp_cascade.result, indices=[arg22, subi2] + ) + DeallocOp(up_B_saved) + DeallocOp(r_B) + DeallocOp(sp_temp) + affine.AffineYieldOp([]) + + with InsertionPoint(affine_if_middle.else_block): + Gp_cascade2 = get_gp_cascade() + up_cascade2 = AllocOp(memref_lqp_l1, [], []) + sp_cascade2 = AllocOp(memref_lqp_l1, [], []) + ChannelGet( + "cascade", Gp_cascade2, indices=[arg22, arg23] + ) + ChannelGet( + "cascade", + up_cascade2.result, + indices=[arg22, arg23], + ) + ChannelGet( + "cascade", + sp_cascade2.result, + indices=[arg22, arg23], + ) + up_B_saved2 = AllocOp(memref_lqp_l1, [], []) + c0_i32_f = ConstantOp(i32, 0) + CallOp( + [], + "vector_copy_32elems", + [c0_i32_f, arg27, 
up_B_saved2.result], + ) + CallOp( + [], "maximum_up_u_bf16", [up_cascade2.result, arg27] + ) + CallOp( + [], + "exp_up_minus_u", + [up_cascade2.result, arg27, r_l1_c.result], + ) + r_B2 = AllocOp(memref_lqp_l1, [], []) + CallOp( + [], + "exp_up_minus_u", + [up_B_saved2.result, arg27, r_B2.result], + ) + CallOp([], "mul_r_gp", [r_l1_c.result, Gp_cascade2]) + CallOp([], "mul_r_gp", [r_B2.result, arg29]) + CallOp([], "add_gp_g", [arg29, Gp_cascade2]) + sp_temp2 = AllocOp(memref_lqp_l1, [], []) + CallOp([], "zero_fill_sp_bf16", [sp_temp2.result]) + CallOp( + [], + "accum_sp_r_s", + [ + sp_cascade2.result, + r_l1_c.result, + sp_temp2.result, + ], + ) + CallOp( + [], + "accum_sp_r_s", + [arg28, r_B2.result, sp_temp2.result], + ) + CallOp( + [], + "vector_copy_32elems", + [c0_i32_f, sp_temp2.result, sp_cascade2.result], + ) + CallOp( + [], "div_gp_sp", [sp_cascade2.result, Gp_cascade2] + ) + DeallocOp(up_B_saved2) + DeallocOp(r_B2) + DeallocOp(sp_temp2) + ChannelPut( + "L1ToL2Chan1", + Gp_cascade2, + indices=[arg22, 0], + offsets=[0, 0, 0, 0], + sizes=[ + tile_size_q // mmul_n, + mmul_m, + dv // mmul_m, + mmul_n, + ], + strides=[ + mmul_m * mmul_n, + mmul_n, + tile_size_q * mmul_n, + 1, + ], + ) + affine.AffineYieldOp([]) + affine.AffineYieldOp([]) + + # Increment q_block counter for next launch iteration + if causal: + c0_ci = ConstantOp(index_type, 0) + c2_ci = ConstantOp(index_type, 2) + c1_i32_ci = ConstantOp(i32, 1) + # Increment head counter + head_cur = load(counter_buf, [c2_ci]) + head_next = arith.AddIOp(head_cur, c1_i32_ci) + total_heads_i32 = ConstantOp(i32, num_head_groups) + wrapped = arith.CmpIOp( + arith.CmpIPredicate.sge, + head_next, + total_heads_i32, + ) + if_wrap = scf.IfOp(wrapped) + with InsertionPoint(if_wrap.then_block): + # All heads done: increment q_block, reset head + q_cur = load(counter_buf, [c0_ci]) + c_nqt_i32 = ConstantOp(i32, num_q_tiles) + q_next = arith.AddIOp(q_cur, c_nqt_i32) + store(q_next, counter_buf, [c0_ci]) + c0_i32_ci = 
ConstantOp(i32, 0) + store(c0_i32_ci, counter_buf, [c2_ci]) + scf.YieldOp([]) + if_wrap_else = scf.IfOp( + arith.CmpIOp( + arith.CmpIPredicate.slt, + head_next, + total_heads_i32, + ) + ) + with InsertionPoint(if_wrap_else.then_block): + store(head_next, counter_buf, [c2_ci]) + scf.YieldOp([]) + + yield_([]) # end of q_iter loop + + # Output channel gets are inside the combined Q/K/V/output loop above + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(prog="attn.py") + parser.add_argument("-p", "--print-module-only", action="store_true") + parser.add_argument("-v", "--verbose", action="store_true") + parser.add_argument( + "--mlir-file", + type=str, + default=None, + help="Path to external MLIR file to compile (instead of generating)", + ) + parser.add_argument( + "--lk", type=int, default=12288, help="Total sequence length for K/V matrices" + ) + parser.add_argument( + "--lkp", type=int, default=96, help="Chunk size for K/V processing" + ) + parser.add_argument( + "--lq", type=int, default=512, help="Total sequence length for Q matrix" + ) + parser.add_argument( + "--lqp", + type=int, + default=128, + help="Chunk size for Q processing per launch iteration", + ) + parser.add_argument("--dk", type=int, default=64, help="Key dimension") + parser.add_argument("--dv", type=int, default=64, help="Value dimension") + parser.add_argument( + "--num-heads", type=int, default=12, help="Number of Q attention heads" + ) + parser.add_argument( + "--num-kv-heads", + type=int, + default=None, + help="Number of K/V heads (default: num_heads for MHA, set < num_heads for GQA)", + ) + parser.add_argument( + "--compile-mode", + type=str, + default="run", + choices=["run", "compile"], + help="Compilation mode: run (default, compile + test), compile (generate binary only)", + ) + parser.add_argument( + "--causal", + action="store_true", + help="Enable causal masking (autoregressive attention)", + ) + parser.add_argument( + "--val-range", + type=float, + default=3.0, + 
help="Input value range for random test data (default: 3.0)", + ) + args = parser.parse_args() + + lk, lkp, lq, lqp, dk, dv = args.lk, args.lkp, args.lq, args.lqp, args.dk, args.dv + causal = args.causal + num_heads = args.num_heads + num_kv_heads = args.num_kv_heads if args.num_kv_heads is not None else num_heads + + if num_kv_heads <= 0: + raise ValueError(f"num_kv_heads must be positive, got {num_kv_heads}") + if num_heads % num_kv_heads != 0: + raise ValueError( + f"num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})" + ) + + if args.mlir_file: + with open(args.mlir_file, "r") as f: + mlir_source = f.read() + with Context() as ctx, Location.unknown(): + registry = DialectRegistry() + air.dialects.air.register_dialect(registry) + ctx.append_dialect_registry(registry) + ctx.load_all_available_dialects() + mlir_module = Module.parse(mlir_source) + print(f"Loaded MLIR module from: {args.mlir_file}") + else: + mlir_module = build_module( + lk=lk, + lkp=lkp, + lq=lq, + lqp=lqp, + dk=dk, + dv=dv, + num_q_tiles=4, + num_cascade_stages=4, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + causal=causal, + ) + + if args.print_module_only: + print(mlir_module) + exit(0) + + from air.backend.xrt_runner import XRTRunner, type_mapper + from air.backend.xrt import XRTBackend + from air.extras import types as extrasT + from ml_dtypes import bfloat16 + + INPUT_DATATYPE = OUTPUT_DATATYPE = bfloat16 + VM_ACC_DATATYPE = np.float32 + + gqa_group_size = num_heads // num_kv_heads + + rng = np.random.default_rng(42) + val_range = args.val_range + input_q = rng.uniform(0, val_range, (num_heads, lq, dk)).astype(INPUT_DATATYPE) + input_k = rng.uniform(0, val_range, (num_kv_heads, lk, dk)).astype(INPUT_DATATYPE) + input_v = rng.uniform(0, val_range, (num_kv_heads, lk, dv)).astype(INPUT_DATATYPE) + input_m = np.zeros((num_heads, lq, lk), dtype=INPUT_DATATYPE) + + inv_sqrt_dk = 1.0 / sqrt(dk) + + def sdpa_golden(Q, K, V, scale, causal_mask=False): + """Standard 
scaled dot-product attention in f32.""" + scores = (Q.astype(np.float32) @ K.astype(np.float32).T) * scale + if causal_mask: + # Mask the strictly upper triangle (key index > query index) + # with a large negative value so those weights vanish after exp. + mask = np.triu(np.ones(scores.shape, dtype=bool), k=1) + scores = np.where(mask, -1e9, scores) + # Subtract the row max before exp for numerical stability. + m = np.max(scores, axis=-1, keepdims=True) + exp_s = np.exp(scores - m) + P = exp_s / np.sum(exp_s, axis=-1, keepdims=True) + return (P @ V.astype(np.float32)).astype(OUTPUT_DATATYPE) + + sdpa_output = np.zeros((num_heads, lq, dv), dtype=OUTPUT_DATATYPE) + for h in range(num_heads): + # GQA: gqa_group_size consecutive Q heads share one K/V head. + kv_h = h // gqa_group_size + sdpa_output[h] = sdpa_golden( + input_q[h], + input_k[kv_h], + input_v[kv_h], + inv_sqrt_dk, + causal_mask=causal, + ) + + enable_shared_buffers_main = lkp == dk + # Causal mode always keeps the while-true loop: the causal q_block + # counter lives in a core-local buffer and is advanced by the core itself + # (the herd body uses no RTP / herd lock), so the core has to keep + # looping for the counter to persist. + # Non-causal mode keeps the loop only when shared buffers are enabled + # (lkp == dk). + # NOTE(review): an earlier version of this comment described an RTP / + # herd-lock mechanism; confirm nothing else still depends on that. + omit_loop = False if causal else not enable_shared_buffers_main + runner = XRTRunner( + omit_while_true_loop=omit_loop, + omit_pingpong="all", + verbose=args.verbose, + runtime_loop_tiling_sizes=[1, 1], + output_format="elf", + instance_name="attention_bf16", + ) + + if args.compile_mode == "run": + exit( + runner.run_test( + mlir_module, + inputs=[input_q, input_k, input_v, input_m], + expected_outputs=[sdpa_output], + atol=0.15, + rtol=0.04, + max_mismatch_percentage=2, + ) + ) + elif args.compile_mode == "compile": + # Compile-only path: same backend options as the runner above. + backend = XRTBackend( + omit_while_true_loop=omit_loop, + omit_pingpong="all", + verbose=args.verbose, + runtime_loop_tiling_sizes=[1, 1], + output_format="elf", + instance_name="attention_bf16", + ) + module_function = backend.compile(mlir_module) + print(f"Compilation complete.
Generated elf binary") diff --git a/programming_examples/flash_attention/packet_switched/attn_npu1.cc b/programming_examples/flash_attention/packet_switched/attn_npu1.cc new file mode 100644 index 000000000..abeaaca9d --- /dev/null +++ b/programming_examples/flash_attention/packet_switched/attn_npu1.cc @@ -0,0 +1,1012 @@ +//===- attn_npu1.cc - Flash attention kernels for NPU1 (AIE2) ---*- C++ -*-===// +// +// SPDX-License-Identifier: MIT +// Copyright (C) 2025, Advanced Micro Devices, Inc. +// +// NPU1 (AIE2) variant of kernel_fusion_based flash attention. +// Key differences from NPU2 (attn_npu2.cc): +// - mmul<4,8,4> instead of mmul<8,8,8> +// - LUT-based exp instead of aie::exp2 +// - Column-major 4x4 block tiling instead of 8x8 +// - aie::div instead of aie::inv +// - scale_g_bf16: explicit 1/sqrt(dk) scaling after matmul +// +//===----------------------------------------------------------------------===// + +#define NOCPP + +#include +#include +#include +#include + +#define REL_WRITE 0 +#define REL_READ 1 + +#include + +#include "lut_based_ops.h" +#include "zero.cc" + +// Default values if not provided by Makefile +#ifndef lqp +#define lqp 32 +#endif + +#ifndef lkp +#define lkp 96 +#endif + +#ifndef dk +#define dk 64 +#endif + +#ifndef dv +#define dv 64 +#endif + +#ifndef dv_full +#define dv_full dv +#endif + +#ifndef dk_full +#define dk_full dk +#endif + +// ============================================================================ +// Matmul template: 4x4 expansion with transpose_b control for AIE2 mmul<4,8,4> +// ============================================================================ + +// Column-major B matmul with compile-time transpose control. +// transpose_b: true = apply aie::transpose before mac (K DMA: inner [n_in, +// k_in]) +// false = load B as-is, hardware mul_4x8_4x8T transposes (V DMA: +// inner [k_in, n_in]) +// A and C are always column-major tiled. 
+template +static inline void matmul_vectorized_4x4(const T_in *__restrict pA, + const T_in *__restrict pB, + T_out *__restrict pC) { + + using MMUL = aie::mmul; + + event0(); + + for (unsigned z = 0; z < rowA; z += 4) + chess_prepare_for_pipelining chess_loop_range(2, ) { + T_out *__restrict pC1 = pC + (z)*MMUL::size_C; + T_out *__restrict pC2 = pC + ((z + 1)) * MMUL::size_C; + T_out *__restrict pC3 = pC + ((z + 2)) * MMUL::size_C; + T_out *__restrict pC4 = pC + ((z + 3)) * MMUL::size_C; + + for (unsigned j = 0; j < colB; j += 4) +#ifdef OPT_PERF_ENABLED + chess_flatten_loop +#endif + { + const T_in *__restrict pA1 = pA + (z)*MMUL::size_A; + const T_in *__restrict pA2 = pA + ((z + 1)) * MMUL::size_A; + const T_in *__restrict pA3 = pA + ((z + 2)) * MMUL::size_A; + const T_in *__restrict pA4 = pA + ((z + 3)) * MMUL::size_A; + + const T_in *__restrict pB1 = pB + (j)*colA * MMUL::size_B; + const T_in *__restrict pB2 = pB + ((j + 1)) * colA * MMUL::size_B; + const T_in *__restrict pB3 = pB + ((j + 2)) * colA * MMUL::size_B; + const T_in *__restrict pB4 = pB + ((j + 3)) * colA * MMUL::size_B; + + aie::vector A0 = aie::load_v(pA1); + pA1 += rowA * MMUL::size_A; + aie::vector A1 = aie::load_v(pA2); + pA2 += rowA * MMUL::size_A; + aie::vector A2 = aie::load_v(pA3); + pA3 += rowA * MMUL::size_A; + aie::vector A3 = aie::load_v(pA4); + pA4 += rowA * MMUL::size_A; + + aie::vector B0, B1, B2, B3; + if constexpr (transpose_b) { + // K DMA k-major block layout: block (n=j, k=i) at i*colB+j. + // Sub-tile elements are [n_in, k_in], transpose to [k_in, n_in]. 
+ const T_in *__restrict pBk0 = pB + (0 * colB + j) * MMUL::size_B; + const T_in *__restrict pBk1 = + pB + (0 * colB + (j + 1)) * MMUL::size_B; + const T_in *__restrict pBk2 = + pB + (0 * colB + (j + 2)) * MMUL::size_B; + const T_in *__restrict pBk3 = + pB + (0 * colB + (j + 3)) * MMUL::size_B; + B0 = aie::transpose(aie::load_v(pBk0), t, s); + B1 = aie::transpose(aie::load_v(pBk1), t, s); + B2 = aie::transpose(aie::load_v(pBk2), t, s); + B3 = aie::transpose(aie::load_v(pBk3), t, s); + } else { + B0 = aie::load_v(pB1); + B1 = aie::load_v(pB2); + B2 = aie::load_v(pB3); + B3 = aie::load_v(pB4); + } + pB1 += MMUL::size_B; + pB2 += MMUL::size_B; + pB3 += MMUL::size_B; + pB4 += MMUL::size_B; + + aie::vector acc_C00 = + aie::load_v(pC1); + aie::vector acc_C01 = + aie::load_v(pC1 + MMUL::size_C * rowA); + aie::vector acc_C02 = + aie::load_v(pC1 + 2 * MMUL::size_C * rowA); + aie::vector acc_C03 = + aie::load_v(pC1 + 3 * MMUL::size_C * rowA); + + aie::vector acc_C10 = + aie::load_v(pC2); + aie::vector acc_C11 = + aie::load_v(pC2 + MMUL::size_C * rowA); + aie::vector acc_C12 = + aie::load_v(pC2 + 2 * MMUL::size_C * rowA); + aie::vector acc_C13 = + aie::load_v(pC2 + 3 * MMUL::size_C * rowA); + + aie::vector acc_C20 = + aie::load_v(pC3); + aie::vector acc_C21 = + aie::load_v(pC3 + MMUL::size_C * rowA); + aie::vector acc_C22 = + aie::load_v(pC3 + 2 * MMUL::size_C * rowA); + aie::vector acc_C23 = + aie::load_v(pC3 + 3 * MMUL::size_C * rowA); + + aie::vector acc_C30 = + aie::load_v(pC4); + aie::vector acc_C31 = + aie::load_v(pC4 + MMUL::size_C * rowA); + aie::vector acc_C32 = + aie::load_v(pC4 + 2 * MMUL::size_C * rowA); + aie::vector acc_C33 = + aie::load_v(pC4 + 3 * MMUL::size_C * rowA); + + MMUL C00(acc_C00); + MMUL C01(acc_C01); + MMUL C02(acc_C02); + MMUL C03(acc_C03); + + MMUL C10(acc_C10); + MMUL C11(acc_C11); + MMUL C12(acc_C12); + MMUL C13(acc_C13); + + MMUL C20(acc_C20); + MMUL C21(acc_C21); + MMUL C22(acc_C22); + MMUL C23(acc_C23); + + MMUL C30(acc_C30); + MMUL 
C31(acc_C31); + MMUL C32(acc_C32); + MMUL C33(acc_C33); + + C00.mac(A0, B0); + C01.mac(A0, B1); + C10.mac(A1, B0); + C11.mac(A1, B1); + + C02.mac(A0, B2); + C03.mac(A0, B3); + C12.mac(A1, B2); + C13.mac(A1, B3); + + C20.mac(A2, B0); + C21.mac(A2, B1); + C30.mac(A3, B0); + C31.mac(A3, B1); + + C22.mac(A2, B2); + C23.mac(A2, B3); + C32.mac(A3, B2); + C33.mac(A3, B3); + + for (unsigned i = 1; i < colA; ++i) +#ifdef OPT_PERF_ENABLED + chess_flatten_loop +#endif + { + A0 = aie::load_v(pA1); + pA1 += rowA * MMUL::size_A; + A1 = aie::load_v(pA2); + pA2 += rowA * MMUL::size_A; + A2 = aie::load_v(pA3); + pA3 += rowA * MMUL::size_A; + A3 = aie::load_v(pA4); + pA4 += rowA * MMUL::size_A; + + if constexpr (transpose_b) { + const T_in *__restrict pBk0 = + pB + (i * colB + j) * MMUL::size_B; + const T_in *__restrict pBk1 = + pB + (i * colB + (j + 1)) * MMUL::size_B; + const T_in *__restrict pBk2 = + pB + (i * colB + (j + 2)) * MMUL::size_B; + const T_in *__restrict pBk3 = + pB + (i * colB + (j + 3)) * MMUL::size_B; + B0 = aie::transpose(aie::load_v(pBk0), t, s); + B1 = aie::transpose(aie::load_v(pBk1), t, s); + B2 = aie::transpose(aie::load_v(pBk2), t, s); + B3 = aie::transpose(aie::load_v(pBk3), t, s); + } else { + B0 = aie::load_v(pB1); + B1 = aie::load_v(pB2); + B2 = aie::load_v(pB3); + B3 = aie::load_v(pB4); + } + pB1 += MMUL::size_B; + pB2 += MMUL::size_B; + pB3 += MMUL::size_B; + pB4 += MMUL::size_B; + + C00.mac(A0, B0); + C01.mac(A0, B1); + C10.mac(A1, B0); + C11.mac(A1, B1); + + C02.mac(A0, B2); + C03.mac(A0, B3); + C12.mac(A1, B2); + C13.mac(A1, B3); + + C20.mac(A2, B0); + C21.mac(A2, B1); + C30.mac(A3, B0); + C31.mac(A3, B1); + + C22.mac(A2, B2); + C23.mac(A2, B3); + C32.mac(A3, B2); + C33.mac(A3, B3); + } + + aie::store_v(pC1, C00.template to_vector()); + pC1 += MMUL::size_C * rowA; + aie::store_v(pC1, C01.template to_vector()); + pC1 += MMUL::size_C * rowA; + aie::store_v(pC1, C02.template to_vector()); + pC1 += MMUL::size_C * rowA; + aie::store_v(pC1, C03.template 
to_vector()); + pC1 += MMUL::size_C * rowA; + + aie::store_v(pC2, C10.template to_vector()); + pC2 += MMUL::size_C * rowA; + aie::store_v(pC2, C11.template to_vector()); + pC2 += MMUL::size_C * rowA; + aie::store_v(pC2, C12.template to_vector()); + pC2 += MMUL::size_C * rowA; + aie::store_v(pC2, C13.template to_vector()); + pC2 += MMUL::size_C * rowA; + + aie::store_v(pC3, C20.template to_vector()); + pC3 += MMUL::size_C * rowA; + aie::store_v(pC3, C21.template to_vector()); + pC3 += MMUL::size_C * rowA; + aie::store_v(pC3, C22.template to_vector()); + pC3 += MMUL::size_C * rowA; + aie::store_v(pC3, C23.template to_vector()); + pC3 += MMUL::size_C * rowA; + + aie::store_v(pC4, C30.template to_vector()); + pC4 += MMUL::size_C * rowA; + aie::store_v(pC4, C31.template to_vector()); + pC4 += MMUL::size_C * rowA; + aie::store_v(pC4, C32.template to_vector()); + pC4 += MMUL::size_C * rowA; + aie::store_v(pC4, C33.template to_vector()); + pC4 += MMUL::size_C * rowA; + } + } + + event1(); +} + +// bf16 MatMul kernel with bf16 outputs for AIE2 (4x8x4). +// transpose_b: controls whether B blocks are software-transposed before mac. 
+template +static inline void +matmul_vectorized_4x8x4_bf16_bf16(const bfloat16 *__restrict pA, + const bfloat16 *__restrict pB, + bfloat16 *__restrict pC) { + constexpr int r = 4; + constexpr int s = 8; + constexpr int t = 4; + static_assert(m % (4 * r) == 0); // 'm' dimension + static_assert(k % s == 0); // 'k' dimension + static_assert(n % (4 * t) == 0); // 'n' dimension + + return matmul_vectorized_4x4(pA, pB, pC); +} + +// ============================================================================ +// LUT-based exponential for AIE2 (no native exp2) +// ============================================================================ + +alignas(aie::vector_decl_align) extern int16 exp_ilut_ab[512]; +alignas(aie::vector_decl_align) extern int16 exp_ilut_cd[512]; +alignas(aie::vector_decl_align) extern int16 exp_flut_ab[512]; +alignas(aie::vector_decl_align) extern int16 exp_flut_cd[512]; + +__attribute__((always_inline)) v16accfloat getExpBf16(v16bfloat16 x) { + bfloat16 __aie_dm_resource_a *ilut_ab = + (bfloat16 __aie_dm_resource_a *)exp_ilut_ab; + bfloat16 __aie_dm_resource_b *ilut_cd = + (bfloat16 __aie_dm_resource_b *)exp_ilut_cd; + bfloat16 __aie_dm_resource_a *flut_ab = + (bfloat16 __aie_dm_resource_a *)exp_flut_ab; + bfloat16 __aie_dm_resource_b *flut_cd = + (bfloat16 __aie_dm_resource_b *)exp_flut_cd; + + using lut_type = aie::lut<4, bfloat16, bfloat16>; + const int LUT_elems = 256; + const int step_i = 8; + const int step_f = 0; + + lut_type lut_i(LUT_elems, ilut_ab, ilut_cd); + lut_type lut_f(LUT_elems, flut_ab, flut_cd); + aie::parallel_lookup + lookup_i(lut_i, step_i); + aie::parallel_lookup + lookup_f(lut_f, step_f); + + aie::vector I_val_vec, F_val_vec; + aie::accum exp_val; + aie::vector input_bf16 = x; + + // position of output decimal point = 8, making input become 8 bits, and for + // LUT_elems = 256 lookup. 
+ aie::vector input0 = v32int16(bfloat16_to_int(input_bf16, 8)); + aie::vector input = aie::filter_even(input0); + + I_val_vec = lookup_i.fetch(input.cast_to()); + F_val_vec = lookup_f.fetch(input.cast_to()); + exp_val = aie::mul(I_val_vec, F_val_vec); + return v16accfloat(exp_val); +} + +// ============================================================================ +// Scaling constant for 1/sqrt(dk_full) +// ============================================================================ +#include + +static const double inv_sqrt_dk_val = 1.0 / sqrt((double)dk_full); + +#define inv_sqrt_dk inv_sqrt_dk_val + +// ============================================================================ +// Kernel functions +// ============================================================================ + +extern "C" { + +// Copy tile_size_q x dk elements from src to dst (single-pass vector copy) +void copy_tile(bfloat16 *src, bfloat16 *dst) { + constexpr int VecLen = 16; + constexpr int num_elems = lqp * dk; + bfloat16 *__restrict ps = src; + bfloat16 *__restrict pd = dst; + for (unsigned j = 0; j < num_elems / VecLen; j++) + chess_prepare_for_pipelining chess_loop_range(8, ) { + aie::vector v = aie::load_v(ps); + aie::store_v(pd, v); + ps += VecLen; + pd += VecLen; + } +} + +void matmul_a_b_bf16(bfloat16 *a_in, bfloat16 *b_in, bfloat16 *out) { + // Buffer shapes: + // A: [lqp, dk] (Q tile, column-major 4x4 tiled) + // B: [lkp, dk] (K chunk, transpose per block) + // Out: [lqp, lkp] (G matrix, column-major 4x4 tiled) + matmul_vectorized_4x8x4_bf16_bf16(a_in, b_in, out); +} + +void matmul_g_b_bf16(bfloat16 *g_in, bfloat16 *b_in, bfloat16 *out) { + // Buffer shapes: + // G: [lqp, lkp] (attention scores, column-major 4x4 tiled) + // B: [lkp, dv] (V chunk, no software transpose) + // Out: [lqp, dv] (attention output, column-major 4x4 tiled) + // + // G is in 4x4 column-major block layout (from QK matmul C output): + // Block [rb, cb] at g_in + rb * size_C + cb * rowA_C * size_C + // 
where size_C = r*t = 16, rowA_C = lqp/r = lqp/4. + // But mmul<4,8,4> needs A in 4x8 block format (size_A = r*s = 32). + // We load two adjacent column-blocks and interleave them into a 4x8 sub-tile. + // + // matmul: G[lqp, lkp] x V[lkp, dv] -> Out[lqp, dv] + constexpr int r = 4; + constexpr int s = 8; + constexpr int t = 4; + constexpr unsigned rowA = lqp / r; // number of row-blocks of A/C + constexpr unsigned colA = lkp / s; // number of k-blocks (A is 4x8) + constexpr unsigned colB = dv / t; // number of n-blocks of B/C + using MMUL = aie::mmul; + + // 4x4 C-block layout parameters + constexpr unsigned size_C_blk = r * t; // 16 elements per 4x4 block + constexpr unsigned col_block_stride = + rowA * size_C_blk; // stride between column-blocks: (lqp / 4) * 16 elements (256 only when lqp == 64) + + event0(); + + for (unsigned z = 0; z < rowA; z += 4) + chess_prepare_for_pipelining chess_loop_range(2, ) { + bfloat16 *__restrict pC1 = out + (z)*MMUL::size_C; + bfloat16 *__restrict pC2 = out + ((z + 1)) * MMUL::size_C; + bfloat16 *__restrict pC3 = out + ((z + 2)) * MMUL::size_C; + bfloat16 *__restrict pC4 = out + ((z + 3)) * MMUL::size_C; + + for (unsigned j = 0; j < colB; j += 4) { + const bfloat16 *__restrict pB1 = b_in + (j)*colA * MMUL::size_B; + const bfloat16 *__restrict pB2 = b_in + ((j + 1)) * colA * MMUL::size_B; + const bfloat16 *__restrict pB3 = b_in + ((j + 2)) * colA * MMUL::size_B; + const bfloat16 *__restrict pB4 = b_in + ((j + 3)) * colA * MMUL::size_B; + + // Load A from 4x4 block format: read two 4x4 blocks, interleave to 4x8 + // For A sub-tile [rb, kb]: read G blocks [rb, cb=2*kb] and [rb, cb=2*kb+1] + auto load_A_4x4 = + [&](unsigned rb, + unsigned kb) -> aie::vector { + const bfloat16 *pLo = + g_in + rb * size_C_blk + (2 * kb) * col_block_stride; + const bfloat16 *pHi = + g_in + rb * size_C_blk + (2 * kb + 1) * col_block_stride; + aie::vector lo = aie::load_v<16>(pLo); + aie::vector hi = aie::load_v<16>(pHi); + // interleave_zip with step=4: takes alternating groups of 4 from lo, + // hi lo
= [r0c0..3 r1c0..3 r2c0..3 r3c0..3] hi = [r0c4..7 r1c4..7 + // r2c4..7 r3c4..7] result_lo = [r0c0..3 r0c4..7 r1c0..3 r1c4..7] + // (rows 0-1, 8 cols) result_hi = [r2c0..3 r2c4..7 r3c0..3 r3c4..7] + // (rows 2-3, 8 cols) + auto [zlo, zhi] = aie::interleave_zip(lo, hi, 4); + return aie::concat(zlo, zhi); + }; + + aie::vector A0 = load_A_4x4(z, 0); + aie::vector A1 = load_A_4x4(z + 1, 0); + aie::vector A2 = load_A_4x4(z + 2, 0); + aie::vector A3 = load_A_4x4(z + 3, 0); + + aie::vector B0, B1, B2, B3; + B0 = aie::load_v(pB1); + B1 = aie::load_v(pB2); + B2 = aie::load_v(pB3); + B3 = aie::load_v(pB4); + pB1 += MMUL::size_B; + pB2 += MMUL::size_B; + pB3 += MMUL::size_B; + pB4 += MMUL::size_B; + + aie::vector acc_C00 = + aie::load_v(pC1); + aie::vector acc_C01 = + aie::load_v(pC1 + MMUL::size_C * rowA); + aie::vector acc_C02 = + aie::load_v(pC1 + 2 * MMUL::size_C * rowA); + aie::vector acc_C03 = + aie::load_v(pC1 + 3 * MMUL::size_C * rowA); + + aie::vector acc_C10 = + aie::load_v(pC2); + aie::vector acc_C11 = + aie::load_v(pC2 + MMUL::size_C * rowA); + aie::vector acc_C12 = + aie::load_v(pC2 + 2 * MMUL::size_C * rowA); + aie::vector acc_C13 = + aie::load_v(pC2 + 3 * MMUL::size_C * rowA); + + aie::vector acc_C20 = + aie::load_v(pC3); + aie::vector acc_C21 = + aie::load_v(pC3 + MMUL::size_C * rowA); + aie::vector acc_C22 = + aie::load_v(pC3 + 2 * MMUL::size_C * rowA); + aie::vector acc_C23 = + aie::load_v(pC3 + 3 * MMUL::size_C * rowA); + + aie::vector acc_C30 = + aie::load_v(pC4); + aie::vector acc_C31 = + aie::load_v(pC4 + MMUL::size_C * rowA); + aie::vector acc_C32 = + aie::load_v(pC4 + 2 * MMUL::size_C * rowA); + aie::vector acc_C33 = + aie::load_v(pC4 + 3 * MMUL::size_C * rowA); + + MMUL C00(acc_C00); + MMUL C01(acc_C01); + MMUL C02(acc_C02); + MMUL C03(acc_C03); + MMUL C10(acc_C10); + MMUL C11(acc_C11); + MMUL C12(acc_C12); + MMUL C13(acc_C13); + MMUL C20(acc_C20); + MMUL C21(acc_C21); + MMUL C22(acc_C22); + MMUL C23(acc_C23); + MMUL C30(acc_C30); + MMUL C31(acc_C31); 
+ MMUL C32(acc_C32); + MMUL C33(acc_C33); + + C00.mac(A0, B0); + C01.mac(A0, B1); + C10.mac(A1, B0); + C11.mac(A1, B1); + C02.mac(A0, B2); + C03.mac(A0, B3); + C12.mac(A1, B2); + C13.mac(A1, B3); + C20.mac(A2, B0); + C21.mac(A2, B1); + C30.mac(A3, B0); + C31.mac(A3, B1); + C22.mac(A2, B2); + C23.mac(A2, B3); + C32.mac(A3, B2); + C33.mac(A3, B3); + + for (unsigned i = 1; i < colA; ++i) { + A0 = load_A_4x4(z, i); + A1 = load_A_4x4(z + 1, i); + A2 = load_A_4x4(z + 2, i); + A3 = load_A_4x4(z + 3, i); + + B0 = aie::load_v(pB1); + B1 = aie::load_v(pB2); + B2 = aie::load_v(pB3); + B3 = aie::load_v(pB4); + pB1 += MMUL::size_B; + pB2 += MMUL::size_B; + pB3 += MMUL::size_B; + pB4 += MMUL::size_B; + + C00.mac(A0, B0); + C01.mac(A0, B1); + C10.mac(A1, B0); + C11.mac(A1, B1); + C02.mac(A0, B2); + C03.mac(A0, B3); + C12.mac(A1, B2); + C13.mac(A1, B3); + C20.mac(A2, B0); + C21.mac(A2, B1); + C30.mac(A3, B0); + C31.mac(A3, B1); + C22.mac(A2, B2); + C23.mac(A2, B3); + C32.mac(A3, B2); + C33.mac(A3, B3); + } + + aie::store_v(pC1, C00.template to_vector()); + pC1 += MMUL::size_C * rowA; + aie::store_v(pC1, C01.template to_vector()); + pC1 += MMUL::size_C * rowA; + aie::store_v(pC1, C02.template to_vector()); + pC1 += MMUL::size_C * rowA; + aie::store_v(pC1, C03.template to_vector()); + pC1 += MMUL::size_C * rowA; + + aie::store_v(pC2, C10.template to_vector()); + pC2 += MMUL::size_C * rowA; + aie::store_v(pC2, C11.template to_vector()); + pC2 += MMUL::size_C * rowA; + aie::store_v(pC2, C12.template to_vector()); + pC2 += MMUL::size_C * rowA; + aie::store_v(pC2, C13.template to_vector()); + pC2 += MMUL::size_C * rowA; + + aie::store_v(pC3, C20.template to_vector()); + pC3 += MMUL::size_C * rowA; + aie::store_v(pC3, C21.template to_vector()); + pC3 += MMUL::size_C * rowA; + aie::store_v(pC3, C22.template to_vector()); + pC3 += MMUL::size_C * rowA; + aie::store_v(pC3, C23.template to_vector()); + pC3 += MMUL::size_C * rowA; + + aie::store_v(pC4, C30.template to_vector()); + pC4 += 
MMUL::size_C * rowA; + aie::store_v(pC4, C31.template to_vector()); + pC4 += MMUL::size_C * rowA; + aie::store_v(pC4, C32.template to_vector()); + pC4 += MMUL::size_C * rowA; + aie::store_v(pC4, C33.template to_vector()); + pC4 += MMUL::size_C * rowA; + } + } + + event1(); +} + +void zero_fill_gp_bf16(bfloat16 *c_out) { + // Buffer shape: [lqp, dv] + zero_vectorized(c_out); +} + +void zero_fill_sp_bf16(bfloat16 *c_out) { + // Buffer shape: [lqp, 1] + zero_vectorized(c_out); +} + +void zero_fill_g_bf16(bfloat16 *c_out) { + // Buffer shape: [lqp, lkp] + zero_vectorized(c_out); +} + +void neg_inf_fill_up_bf16(bfloat16 *c_out) { + // Buffer shape: [lqp, 1] + neg_inf_vectorized(c_out); +} + +// Scale G by 1/sqrt(dk_full) in-place. +// G is column-major 4x4 block tiled: [lqp, lkp]. +void scale_g_bf16(bfloat16 *g) { + constexpr int VecLen = 16; + constexpr int num_elems = lqp * lkp; + bfloat16 scale_val = (bfloat16)inv_sqrt_dk; + aie::vector scale_vec = + aie::broadcast(scale_val); + bfloat16 *__restrict pG = g; + for (int i = 0; i < num_elems; i += VecLen) { + aie::vector v = aie::load_v(pG); + aie::accum acc = aie::mul(v, scale_vec); + aie::store_v(pG, acc.to_vector()); + pG += VecLen; + } +} + +// Row-wise max of G matrix. +// G is column-major 4x4 block tiled. +// VecLen=16 reads one full 4x4 block (4 rows x 4 cols). +// Within a 16-wide vector: elements [0..3]=row0, [4..7]=row1, [8..11]=row2, +// [12..15]=row3. Since aie::vector is not supported on AIE2, +// use scalar element access for per-row reduction. +void max_g_bf16(bfloat16 *in, bfloat16 *out) { + constexpr int VecLen = 16; + constexpr int BlockSize = 16; // 4x4 block + constexpr int ColsPerBlock = 4; + constexpr int RowsPerBlock = 4; + constexpr int col_blocks = lkp / ColsPerBlock; + constexpr int row_blocks = lqp / RowsPerBlock; + constexpr int block_stride = + lqp * ColsPerBlock; // stride between column blocks + + // Use bf16 lowest (0xff7f) instead of -inf to avoid NaN propagation. 
+ uint16_t lowest_u16 = (uint16_t)0xff7f; + bfloat16 lowest_val = *(bfloat16 *)&lowest_u16; + + bfloat16 *__restrict pOut = out; + for (int rb = 0; rb < row_blocks; rb++) { + aie::vector max_vec = + aie::broadcast(lowest_val); + int base = rb * BlockSize; + for (int cb = 0; cb < col_blocks; cb++) + chess_prepare_for_pipelining chess_loop_range(8, ) { + aie::vector v = + aie::load_v(in + base + cb * block_stride); + max_vec = aie::max(max_vec, v); + } + // Extract per-row max via scalar access. + // Row i occupies elements [i*4 .. i*4+3] in the 16-wide vector. + for (int row = 0; row < RowsPerBlock; row++) { + bfloat16 m = max_vec[row * ColsPerBlock]; + for (int c = 1; c < ColsPerBlock; c++) { + bfloat16 val = max_vec[row * ColsPerBlock + c]; + if (val > m) + m = val; + } + pOut[row] = m; + } + pOut += RowsPerBlock; + } +} + +void maximum_up_u_bf16(bfloat16 *up, bfloat16 *u) { + // u = max(u, up) + // Buffer shape: [lqp, 1] + constexpr int VecLen = 16; + constexpr int num_elems = lqp; + bfloat16 *__restrict pu = u; + for (int i = 0; i < num_elems; i += VecLen) { + aie::vector up_temp = aie::load_v(up + i); + aie::vector u_temp = aie::load_v(pu); + u_temp = aie::max(up_temp, u_temp); + aie::store_v(pu, u_temp); + pu += VecLen; + } +} + +// G = exp(G - u) in-place. G is column-major 4x4 block tiled. +// VecLen=16 processes one full 4x4 block (4 rows x 4 cols). +// Uses LUT-based exp. +void exp_g_minus_u(bfloat16 *u, bfloat16 *g) { + constexpr int VecLen = 16; + constexpr int BlockSize = 16; + constexpr int ColsPerBlock = 4; + constexpr int RowsPerBlock = 4; + constexpr int col_blocks = lkp / ColsPerBlock; + constexpr int row_blocks = lqp / RowsPerBlock; + constexpr int block_stride = lqp * ColsPerBlock; + + for (int rb = 0; rb < row_blocks; rb++) { + // Build 16-wide u vector: 4 rows x 4 cols, each row's u broadcast to its + // 4 column elements. Use scalar set since vector not supported. 
+ int row_start = rb * RowsPerBlock; + aie::vector u_vec = aie::zeros(); + for (int row = 0; row < RowsPerBlock; row++) { + bfloat16 uval = u[row_start + row]; + for (int c = 0; c < ColsPerBlock; c++) { + u_vec[row * ColsPerBlock + c] = uval; + } + } + + int base = rb * BlockSize; + for (int cb = 0; cb < col_blocks; cb++) + chess_prepare_for_pipelining chess_loop_range(8, ) { + int off = base + cb * block_stride; + aie::vector v = aie::load_v(g + off); + v = aie::sub(v, u_vec); + // LUT-based exp: getExpBf16 takes v16bfloat16, returns v16accfloat + aie::vector exp_val = to_v16bfloat16(getExpBf16(v)); + aie::store_v(g + off, exp_val); + } + } +} + +// r = exp(up - u). Uses LUT-based exp. +void exp_up_minus_u(bfloat16 *up, bfloat16 *u, bfloat16 *r) { + constexpr int VecLen = 16; + constexpr int num_elems = lqp; + bfloat16 *__restrict pr = r; + bfloat16 *__restrict pu = u; + bfloat16 *__restrict pup = up; + for (int i = 0; i < num_elems; i += VecLen) { + aie::vector uTemp = aie::load_v(pu); + aie::vector upTemp = aie::load_v(pup); + aie::vector diff = aie::sub(upTemp, uTemp); + // LUT-based exp + aie::vector exp_val = to_v16bfloat16(getExpBf16(diff)); + aie::store_v(pr, exp_val); + pr += VecLen; + pu += VecLen; + pup += VecLen; + } +} + +// Gp = Gp * r (per-row scaling). +// Gp is column-major 4x4 block tiled: [lqp, dv]. 
+// Parameters:
+//   r  - lqp per-row scale factors, one bfloat16 per row of Gp (read only).
+//   gp - [lqp, dv] tile, rescaled in place.
+// Element (row, col) of Gp is addressed as
+//   (col / 4) * (lqp * 4) + (row / 4) * 16 + (row % 4) * 4 + (col % 4),
+// so one 16-element vector covers a whole 4x4 block (4 rows x 4 cols).
+void mul_r_gp(bfloat16 *r, bfloat16 *gp) {
+  constexpr int VecLen = 16;
+  constexpr int BlockSize = 16; // 4x4 block
+  constexpr int ColsPerBlock = 4;
+  constexpr int RowsPerBlock = 4;
+  constexpr int col_blocks = dv / ColsPerBlock;
+  constexpr int row_blocks = lqp / RowsPerBlock;
+  constexpr int block_stride =
+      lqp * ColsPerBlock; // stride between column blocks
+
+  for (int rb = 0; rb < row_blocks; rb++) {
+    // Build a 16-wide r vector: lanes [4*row .. 4*row+3] all hold r[row], so
+    // one element-wise multiply scales every block row by its own factor.
+    int row_start = rb * RowsPerBlock;
+    aie::vector<bfloat16, VecLen> r_vec = aie::zeros<bfloat16, VecLen>();
+    for (int row = 0; row < RowsPerBlock; row++) {
+      bfloat16 rval = r[row_start + row];
+      for (int c = 0; c < ColsPerBlock; c++) {
+        r_vec[row * ColsPerBlock + c] = rval;
+      }
+    }
+
+    // Walk the column blocks of this row block; the same r_vec applies to
+    // each of them because they share the same four rows.
+    int base = rb * BlockSize;
+    for (int cb = 0; cb < col_blocks; cb++)
+      chess_prepare_for_pipelining chess_loop_range(8, ) {
+        int off = base + cb * block_stride;
+        aie::vector<bfloat16, VecLen> v = aie::load_v<VecLen>(gp + off);
+        aie::accum<accfloat, VecLen> acc = aie::mul(v, r_vec);
+        aie::store_v(gp + off, acc.to_vector<bfloat16>());
+      }
+  }
+}
+
+// s = sum(G, axis=-1, keepdims=True).
+// G is column-major 4x4 block tiled.
+// Parameters:
+//   g - [lqp, lkp] tile (read only).
+//   s - receives lqp row sums, one bfloat16 per row of G.
+void sum_g(bfloat16 *g, bfloat16 *s) {
+  constexpr int VecLen = 16;
+  constexpr int BlockSize = 16;
+  constexpr int ColsPerBlock = 4;
+  constexpr int RowsPerBlock = 4;
+  constexpr int col_blocks = lkp / ColsPerBlock;
+  constexpr int row_blocks = lqp / RowsPerBlock;
+  constexpr int block_stride = lqp * ColsPerBlock;
+
+  bfloat16 *__restrict ps = s;
+  for (int rb = 0; rb < row_blocks; rb++) {
+    // Accumulate sum across column blocks for 4 rows. After the loop, lane
+    // row*4 + c holds the partial sum of columns {c, c+4, c+8, ...} of `row`.
+    aie::accum<accfloat, VecLen> sum_acc = aie::zeros<accfloat, VecLen>();
+    int base = rb * BlockSize;
+    for (int cb = 0; cb < col_blocks; cb++)
+      chess_prepare_for_pipelining chess_loop_range(8, ) {
+        aie::vector<bfloat16, VecLen> v =
+            aie::load_v<VecLen>(g + base + cb * block_stride);
+        sum_acc = aie::add(sum_acc, v);
+      }
+    // Reduce each 4-element row slice via scalar access.
+    // NOTE(review): the accumulator is read back as float here because
+    // row_sum is a float - confirm against the upstream kernel source.
+    aie::vector<float, VecLen> sum_v = sum_acc.to_vector<float>();
+    for (int row = 0; row < RowsPerBlock; row++) {
+      float row_sum = 0.0f;
+      for (int c = 0; c < ColsPerBlock; c++) {
+        row_sum += sum_v[row * ColsPerBlock + c];
+      }
+      ps[row] = (bfloat16)row_sum;
+    }
+    ps += RowsPerBlock;
+  }
+}
+
+// Parameters:
+//   sp - lqp values (read only).
+//   r  - lqp values (read only).
+//   s  - lqp values, updated in place: s[i] = s[i] + sp[i] * r[i].
+void accum_sp_r_s(bfloat16 *sp, bfloat16 *r, bfloat16 *s) {
+  // s += sp * r
+  // Buffer shape: [lqp, 1]
+  constexpr int VecLen = 16;
+  constexpr int num_elems = lqp;
+  // The loop below moves whole 16-lane vectors; a remainder would run past
+  // the end of the buffers.
+  static_assert(num_elems % VecLen == 0, "lqp must be a multiple of 16");
+  bfloat16 *__restrict pr = r;
+  bfloat16 *__restrict ps = s;
+  bfloat16 *__restrict psp = sp;
+  for (int i = 0; i < num_elems; i += VecLen) {
+    aie::vector<bfloat16, VecLen> rTemp = aie::load_v<VecLen>(pr);
+    aie::vector<bfloat16, VecLen> spTemp = aie::load_v<VecLen>(psp);
+    // Multiply and add in the accumulator, convert to bfloat16 once.
+    aie::accum<accfloat, VecLen> accTemp = aie::mul(rTemp, spTemp);
+    accTemp = aie::add(accTemp, aie::load_v<VecLen>(ps));
+    aie::vector<bfloat16, VecLen> sTemp = to_v16bfloat16(accTemp);
+    aie::store_v(ps, sTemp);
+    pr += VecLen;
+    ps += VecLen;
+    psp += VecLen;
+  }
+}
+
+// Copy lqp bfloat16 elements from `inputs` to `outputs + offset`.
+// Despite the name, the element count is lqp (num_elems below), not 32.
+// Parameters:
+//   offset  - element offset applied to the destination only.
+//   inputs  - source buffer, read from its start.
+//   outputs - destination buffer.
+void vector_copy_32elems(const int offset, const bfloat16 *__restrict inputs,
+                         bfloat16 *__restrict outputs) {
+  constexpr int VecLen = 16;
+  constexpr int num_elems = lqp;
+  // num_elems / VecLen truncates; reject sizes that would drop a tail.
+  static_assert(num_elems % VecLen == 0, "lqp must be a multiple of 16");
+  const bfloat16 *__restrict pIn = inputs;
+  bfloat16 *__restrict pOut = outputs + offset;
+  for (int j = 0; j < num_elems / VecLen; j++) {
+    aie::vector<bfloat16, VecLen> vec = aie::load_v<VecLen>(pIn);
+    pIn += VecLen;
+    aie::store_v(pOut, vec);
+    pOut += VecLen;
+  }
+}
+
+// Gp = Gp / sp (per-row normalization).
+// Gp is column-major 4x4 block tiled: [lqp, dv].
+// Uses aie::div (AIE2-compatible, no aie::inv).
+// Parameters:
+//   sp - lqp per-row divisors, one bfloat16 per row of Gp (read only).
+//        There is no guard against a zero divisor.
+//   gp - [lqp, dv] tile, divided in place.
+// Element (row, col) of Gp is addressed as
+//   (col / 4) * (lqp * 4) + (row / 4) * 16 + (row % 4) * 4 + (col % 4),
+// so one 16-element vector covers a whole 4x4 block (4 rows x 4 cols).
+void div_gp_sp(bfloat16 *sp, bfloat16 *gp) {
+  constexpr int VecLen = 16;
+  constexpr int BlockSize = 16; // 4x4 block
+  constexpr int ColsPerBlock = 4;
+  constexpr int RowsPerBlock = 4;
+  constexpr int col_blocks = dv / ColsPerBlock;
+  constexpr int row_blocks = lqp / RowsPerBlock;
+  constexpr int block_stride =
+      lqp * ColsPerBlock; // stride between column blocks
+
+  for (int rb = 0; rb < row_blocks; rb++) {
+    // Build 16-wide sp vector via scalar access: lanes [4*row .. 4*row+3]
+    // all hold sp[row], matching the row-major order inside a 4x4 block.
+    int row_start = rb * RowsPerBlock;
+    aie::vector<bfloat16, VecLen> sp_vec = aie::zeros<bfloat16, VecLen>();
+    for (int row = 0; row < RowsPerBlock; row++) {
+      bfloat16 spval = sp[row_start + row];
+      for (int c = 0; c < ColsPerBlock; c++) {
+        sp_vec[row * ColsPerBlock + c] = spval;
+      }
+    }
+
+    // The same divisor vector serves every column block of this row block.
+    int base = rb * BlockSize;
+    for (int cb = 0; cb < col_blocks; cb++)
+      chess_prepare_for_pipelining chess_loop_range(8, ) {
+        int off = base + cb * block_stride;
+        aie::vector<bfloat16, VecLen> v = aie::load_v<VecLen>(gp + off);
+        v = aie::div(v, sp_vec);
+        aie::store_v(gp + off, v);
+      }
+  }
+}
+
+// Fused softmax: delegates to existing kernels.
+// On return: up=new_max, sp=sum(exp(G)), r=rescale_factor, G=exp(G-max).
+// Parameters:
+//   g  - [lqp, lkp] score tile, overwritten with exp(scaled G - new max).
+//   up - in: previous running row max; out: updated running row max.
+//   sp - out: row sums of the exponentiated tile (also used as scratch).
+//   r  - out: exp(previous max - new max) (also used as scratch).
+void fused_softmax(bfloat16 *g, bfloat16 *up, bfloat16 *sp, bfloat16 *r) {
+  scale_g_bf16(g);               // G *= inv_sqrt_dk
+  max_g_bf16(g, r);              // r  = row-wise max of G
+  maximum_up_u_bf16(up, r);      // r  = max(r, up): the new running max
+  exp_g_minus_u(r, g);           // G  = exp(G - new max)
+  exp_up_minus_u(up, r, sp);     // sp = exp(old max - new max), held in sp
+  vector_copy_32elems(0, r, up); // up = new max
+  vector_copy_32elems(0, sp, r); // r  = rescale factor, freeing sp
+  sum_g(g, sp);                  // sp = row sums of the new G
+}
+
+// g = gp + g, element-wise over lqp * dv elements. The sum is written back
+// to g; gp is only read.
+void add_gp_g(bfloat16 *gp, bfloat16 *g) {
+  constexpr int VecLen = 16;
+  constexpr int num_elems = lqp * dv;
+  // num_elems / VecLen truncates; reject sizes that would drop a tail.
+  static_assert(num_elems % VecLen == 0,
+                "lqp * dv must be a multiple of 16");
+  bfloat16 *__restrict gp_ptr = gp;
+  bfloat16 *__restrict g_ptr = g;
+  for (int j = 0; j < num_elems / VecLen; j++) {
+    aie::vector<bfloat16, VecLen> gp_vec = aie::load_v<VecLen>(gp_ptr);
+    aie::vector<bfloat16, VecLen> g_vec = aie::load_v<VecLen>(g_ptr);
+    // Add in the accumulator, then convert to bfloat16 once.
+    aie::accum<accfloat, VecLen> acc(gp_vec);
+    acc = aie::add(acc, g_vec);
+    aie::store_v(g_ptr, acc.to_vector<bfloat16>());
+    gp_ptr += VecLen;
+    g_ptr += VecLen;
+  }
+}
+
+// Apply causal mask to QK scores in-place.
+// G is column-major 4x4 block tiled.
+// Uses scalar access since aie::vector is not supported on AIE2. +void apply_causal_mask(bfloat16 *g, int32_t q_block_idx, int32_t kv_block_idx) { + uint16_t neg_inf_u16 = (uint16_t)0xff80; + bfloat16 neg_inf_val = *(bfloat16 *)&neg_inf_u16; + + // 1. Block above diagonal: all masked -> fill with -inf + if (kv_block_idx > q_block_idx) { + constexpr int VecLen = 16; + aie::vector neg_inf_vec = + aie::broadcast(neg_inf_val); + bfloat16 *p = g; + for (int i = 0; i < lqp * lkp; i += VecLen) { + aie::store_v(p, neg_inf_vec); + p += VecLen; + } + return; + } + + // 2. Block below diagonal: no masking needed + if (kv_block_idx < q_block_idx) { + return; + } + + // 3. Diagonal block (kv_block_idx == q_block_idx): + // Use scalar writes for per-element causal masking. + constexpr int BlkDim = 4; + + for (int row = 0; row < lqp; row++) { + int mask_start = row + 1; + int row_blk = row / BlkDim; + int row_in = row % BlkDim; + + for (int col_blk = 0; col_blk < lkp / BlkDim; col_blk++) { + int col_start = col_blk * BlkDim; + int off = col_blk * (lqp * BlkDim) + row_blk * (BlkDim * BlkDim) + + row_in * BlkDim; + + if (col_start >= mask_start) { + // Entire sub-row masked + for (int c = 0; c < BlkDim; c++) { + g[off + c] = neg_inf_val; + } + } else if (col_start + BlkDim > mask_start) { + // Partial: mask columns >= mask_start + for (int c = 0; c < BlkDim; c++) { + if (col_start + c >= mask_start) { + g[off + c] = neg_inf_val; + } + } + } + // else: unmasked, leave unchanged + } + } +} + +} // extern "C" diff --git a/programming_examples/flash_attention/packet_switched/attn_npu1.py b/programming_examples/flash_attention/packet_switched/attn_npu1.py new file mode 100644 index 000000000..ebf602a64 --- /dev/null +++ b/programming_examples/flash_attention/packet_switched/attn_npu1.py @@ -0,0 +1,1341 @@ +# Copyright (C) 2025, Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT +"""Flash attention with packet-switched Q/K routing (NPU1 / AIE2). 
+ +NPU1 variant of the packet-switched flash attention design. Uses +packet-switched DMA channels (channel_type="dma_packet") to time-multiplex +Q and K data through shared compute tile S2MM DMA channels. + +NPU1-specific differences from the NPU2 variant: + - mmul<4,8,4> (M=4, K_mmul=8) instead of mmul<8,8,8> + - num_heads_per_unroll=1 (4x4 tile array fits 1 head) + - LUT-based exponential (no native aie::exp2 on AIE2) + - G layout conversion via interleave_zip in kernel (4x4 C-blocks to 4x8 A-blocks) + - k-major B-block indexing in DMA (matching attn_npu1.cc kernel) + - Output format: xclbin (not elf) + +Channel routing (same packet-switched structure as NPU2): + L2ToL1Chan1 (Q): dma_packet — broadcast to [num_q_tiles, num_cascade_stages] + L2ToL1Chan2 (K): dma_packet — broadcast to [num_q_tiles, num_cascade_stages] + L2ToL1Chan3 (V): dma_stream — circuit-switched per cascade stage +""" + +import argparse +from math import cos, sin, sqrt, exp +import numpy as np + +import air +from air.ir import * +from air.dialects.affine import apply as affine_apply +from air.dialects.air import * +from air.dialects.arith import ConstantOp +from air.dialects.memref import AllocOp, CollapseShapeOp, DeallocOp, load, store +from air.dialects.func import FuncOp, CallOp +from air.dialects.scf import for_, yield_ +from air.dialects import scf, affine, arith + +range_ = for_ + + +@module_builder +def build_module( + lk=12288, + lkp=96, + lq=512, + lqp=128, + dk=64, + dv=64, + num_q_tiles=4, + num_cascade_stages=4, + num_heads=12, + num_kv_heads=None, + causal=False, +): + """Build the attention module using Python bindings + + Args: + lk: Total sequence length for K/V matrices (default: 12288) + lkp: Chunk size for K/V processing per AIE tile (default: 96) + lq: Total sequence length for Q matrix (default: 512) + lqp: Chunk size for Q processing per launch iteration (default: 128) + dk: Key dimension (default: 64) + dv: Value dimension (default: 64) + num_q_tiles: Number of tiles to 
partition Q chunk (lqp) into (default: 4) + num_cascade_stages: Number of cascade pipeline stages (default: 4) + num_heads: Number of Q attention heads (default: 12) + num_kv_heads: Number of K/V heads (default: num_heads for MHA, < num_heads for GQA) + causal: Enable causal masking (default: False) + """ + if num_kv_heads is None: + num_kv_heads = num_heads # MHA: every Q head has its own KV head + + # Validate divisibility requirements + assert lq % lqp == 0, f"lq ({lq}) must be divisible by lqp ({lqp})" + assert ( + lqp % num_q_tiles == 0 + ), f"lqp ({lqp}) must be divisible by num_q_tiles ({num_q_tiles})" + assert lk % lkp == 0, f"lk ({lk}) must be divisible by lkp ({lkp})" + assert ( + lk % (lkp * num_cascade_stages) == 0 + ), f"lk ({lk}) must be divisible by lkp * num_cascade_stages ({lkp * num_cascade_stages})" + tile_size_q_check = lqp // num_q_tiles + enable_shared_buffers = lkp == dk and tile_size_q_check <= lkp + if causal: + assert lq == lk, f"Causal masking requires lq == lk, got lq={lq}, lk={lk}" + assert lkp == dk, ( + f"Causal masking requires lkp == dk (enable_shared_buffers) for " + f"the prefix+suffix BD collapse to produce infinite-loop DMAs " + f"(no PDI reset between iterations). Got lkp={lkp}, dk={dk}." 
+ ) + tile_size_q = lqp // num_q_tiles + assert ( + tile_size_q == lkp + ), f"Causal masking requires tile_size_q == lkp, got {tile_size_q} vs {lkp}" + assert ( + num_heads % 1 == 0 + ), f"num_heads ({num_heads}) must be positive (NPU1: 1 head per segment unroll)" + assert num_kv_heads > 0, "num_kv_heads must be positive" + assert ( + num_heads % num_kv_heads == 0 + ), f"num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})" + gqa_group_size = num_heads // num_kv_heads + + bf16 = Type.parse("bf16") + i32 = IntegerType.get_signless(32) + index_type = IndexType.get() + + # Architecture-specific matrix multiplication dimensions + # AIE2 (NPU1) uses mmul<4,8,4>: M=4 (row), K_mmul=8, N=4 (col) + M = 4 + K_mmul = 8 + mmul_m = M + mmul_k = K_mmul + mmul_n = M + + # NPU1: 4x4 tile array fits 1 head at a time + num_heads_per_unroll = 1 + num_head_groups = num_heads // num_heads_per_unroll + + # Derived parameters + num_chunks = lk // lkp + chunks_per_stage = num_chunks // num_cascade_stages + num_lq_iters = lq // lqp # Total Q iterations + # Q iteration at launch level for both causal and non-causal. + # Keeping Q at launch level avoids DMA task ordering conflicts: when Q + # iterates on-device, Q and K share the same compute-tile S2MM channel, + # and getRepeatCounts groups them into sequential tasks [Q×N, K×M] + # instead of interleaved [Q, K×M, Q, K×M, ...], causing deadlock. + # For causal masking, the launch Q index is threaded through to the herd + # body for the block index computation. 
+ launch_lq_iters = num_lq_iters + device_lq_iters = 1 + tile_size_q = lqp // num_q_tiles # Tile size within each lqp chunk + + # Memory spaces: L1 = 2 : i32, L2 = 1 : i32 + l1_space = IntegerAttr.get(i32, 2) # L1 uses memory space 2 + l2_space = IntegerAttr.get(i32, 1) # L2 uses memory space 1 + + # L1 MemRefTypes (memory space 2 : i32) - used in herd bodies + memref_lqp_dv_l1 = MemRefType.get([tile_size_q, dk], bf16, memory_space=l1_space) + memref_lqp_l1 = MemRefType.get([tile_size_q, 1], bf16, memory_space=l1_space) + memref_lqp_lkp_l1 = MemRefType.get([tile_size_q * lkp], bf16, memory_space=l1_space) + memref_dv_lkp_l1 = MemRefType.get([lkp, dk], bf16, memory_space=l1_space) + memref_g_shared_l1 = MemRefType.get([tile_size_q, lkp], bf16, memory_space=l1_space) + + # L2 MemRefTypes (memory space 1 : i32) - segment allocations + memref_lqp_dk_l2 = MemRefType.get([tile_size_q, dk], bf16, memory_space=l2_space) + memref_dk_lkp_l2 = MemRefType.get([lkp, dk], bf16, memory_space=l2_space) + memref_lkp_dv_l2 = MemRefType.get([lkp, dk], bf16, memory_space=l2_space) + memref_output_lqp_dv_l2 = MemRefType.get( + [lqp, dk], bf16, memory_space=l2_space + ) # Per-iteration output buffer + + # L3 MemRefTypes (no memory space annotation = default L3) - with head dimension + memref_input_q_lq_dk = MemRefType.get([num_heads, lq, dk], bf16) + memref_output_lq_dv = MemRefType.get([num_heads, lq, dk], bf16) + memref_input_k_dk_lk = MemRefType.get([num_kv_heads, lk, dk], bf16) + memref_input_v_lk_dv = MemRefType.get([num_kv_heads, lk, dk], bf16) + memref_input_m_lq_lk = MemRefType.get([num_heads, lq, lk], bf16) + + # Helper function to create external function declarations + def external_func(name, inputs, outputs=None, link_with=None, visibility="private"): + if outputs is None: + outputs = [] + func_type = FunctionType.get(inputs, outputs) + func = FuncOp(name=name, type=func_type, visibility=visibility) + func.attributes["llvm.emit_c_interface"] = UnitAttr.get() + if link_with: 
+ func.attributes["link_with"] = StringAttr.get(link_with) + return func + + # External function declarations + external_func("zero_fill_gp_bf16", [memref_lqp_dv_l1], link_with="attn_npu1.o") + external_func("zero_fill_sp_bf16", [memref_lqp_l1], link_with="attn_npu1.o") + external_func("zero_fill_g_bf16", [memref_lqp_lkp_l1], link_with="attn_npu1.o") + external_func("neg_inf_fill_up_bf16", [memref_lqp_l1], link_with="attn_npu1.o") + external_func( + "matmul_a_b_bf16", + [memref_lqp_dv_l1, memref_dv_lkp_l1, memref_lqp_lkp_l1], + link_with="attn_npu1.o", + ) + external_func( + "matmul_g_b_bf16", + [memref_lqp_lkp_l1, memref_dv_lkp_l1, memref_lqp_dv_l1], + link_with="attn_npu1.o", + ) + external_func( + "max_g_bf16", [memref_lqp_lkp_l1, memref_lqp_l1], link_with="attn_npu1.o" + ) + external_func( + "fused_softmax", + [memref_lqp_lkp_l1, memref_lqp_l1, memref_lqp_l1, memref_lqp_l1], + link_with="attn_npu1.o", + ) + external_func( + "maximum_up_u_bf16", [memref_lqp_l1, memref_lqp_l1], link_with="attn_npu1.o" + ) + external_func( + "exp_g_minus_u", [memref_lqp_l1, memref_lqp_lkp_l1], link_with="attn_npu1.o" + ) + external_func( + "exp_up_minus_u", + [memref_lqp_l1, memref_lqp_l1, memref_lqp_l1], + link_with="attn_npu1.o", + ) + external_func( + "mul_r_gp", [memref_lqp_l1, memref_lqp_dv_l1], link_with="attn_npu1.o" + ) + external_func("sum_g", [memref_lqp_lkp_l1, memref_lqp_l1], link_with="attn_npu1.o") + external_func( + "accum_sp_r_s", + [memref_lqp_l1, memref_lqp_l1, memref_lqp_l1], + link_with="attn_npu1.o", + ) + external_func( + "vector_copy_32elems", + [i32, memref_lqp_l1, memref_lqp_l1], + link_with="attn_npu1.o", + ) + external_func( + "copy_tile", [memref_dv_lkp_l1, memref_lqp_dv_l1], link_with="attn_npu1.o" + ) + external_func( + "div_gp_sp", [memref_lqp_l1, memref_lqp_dv_l1], link_with="attn_npu1.o" + ) + external_func( + "vector_copy_swizzle_elems", + [i32, memref_lqp_lkp_l1, memref_lqp_lkp_l1], + link_with="attn_npu1.o", + ) + external_func( + 
"vector_copy_unswizzle_elems", + [i32, memref_lqp_lkp_l1, memref_lqp_lkp_l1], + link_with="attn_npu1.o", + ) + external_func( + "add_gp_g", [memref_lqp_dv_l1, memref_lqp_dv_l1], link_with="attn_npu1.o" + ) + # Local i32 buffer for passing block indices to apply_causal_mask + # (unconditional i32 stores, kernel handles conditionals) + memref_2xi32_l1 = MemRefType.get([2], i32, memory_space=l1_space) + if causal: + external_func( + "apply_causal_mask", + [memref_lqp_lkp_l1, i32, i32], + link_with="attn_npu1.o", + ) + + # Channel declarations - use num_heads_per_unroll (1) for segment unroll + Channel("L3ToL2Chan1", size=[num_heads_per_unroll, num_cascade_stages]) + Channel("L3ToL2Chan2", size=[num_heads_per_unroll, num_cascade_stages]) + chan_l2_to_l1_2 = Channel( + "L2ToL1Chan2", + size=[1, num_cascade_stages], + broadcast_shape=[num_q_tiles, num_cascade_stages], + ) + chan_l2_to_l1_2.attributes["channel_type"] = StringAttr.get("dma_packet") + if not enable_shared_buffers: + chan_l2_to_l1_1 = Channel( + "L2ToL1Chan1", + size=[num_q_tiles, 1], + broadcast_shape=[num_q_tiles, num_cascade_stages], + ) + chan_l2_to_l1_1.attributes["channel_type"] = StringAttr.get("dma_packet") + chan_l2_to_l1_3 = Channel( + "L2ToL1Chan3", + size=[1, num_cascade_stages], + broadcast_shape=[num_q_tiles, num_cascade_stages], + ) + Channel("L1ToL2Chan1", size=[num_q_tiles, 1]) + Channel("L2ToL3Chan1", size=[num_heads_per_unroll]) + chan_cascade = Channel("cascade", size=[num_q_tiles, num_cascade_stages - 1]) + chan_cascade.attributes["channel_type"] = StringAttr.get("cascade") + + # Main attention function + @FuncOp.from_py_func( + memref_input_q_lq_dk, + memref_input_k_dk_lk, + memref_input_v_lk_dv, + memref_input_m_lq_lk, + memref_output_lq_dv, + ) + def attention_bf16(arg0, arg1, arg2, arg3, arg4): + c_launch_lq = ConstantOp(index_type, launch_lq_iters) + c_num_head_groups = ConstantOp(index_type, num_head_groups) + + # Non-causal: launch iterates Q blocks at host level (no BD chain 
limit) + # Causal: launch size 1, Q iteration inside herd (device-local q_block) + @launch( + operands=[arg0, arg1, arg2, arg4], sizes=[c_launch_lq, c_num_head_groups] + ) + def launch_body(arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12): + # arg5 = Q iteration index (0..launch_lq_iters-1), arg6 = head group + c0 = ConstantOp(index_type, 0) + c1 = ConstantOp(index_type, 1) + + # Compute actual head index from head group + # NPU1: num_heads_per_unroll=1, so head_base = arg6 * 1 = arg6 + affine_map_head_base = AffineMap.get( + 0, + 1, + [ + AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(num_heads_per_unroll), + ) + ], + ) + head_base = affine_apply(affine_map_head_base, [arg6]) + + # GQA: compute KV head index from Q head index + # kv_head = q_head // gqa_group_size + if gqa_group_size == 1: + # MHA: kv_head == q_head + kv_head_base = head_base + else: + affine_map_kv_head = AffineMap.get( + 0, + 1, + [ + AffineExpr.get_floor_div( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(gqa_group_size), + ) + ], + ) + kv_head_base = affine_apply(affine_map_kv_head, [head_base]) + + # Affine map for Q tile partitioning within lqp chunk + affine_map_tileq = AffineMap.get( + 0, + 1, + [ + AffineExpr.get_mul( + AffineSymbolExpr.get(0), AffineConstantExpr.get(tile_size_q) + ) + ], + ) + # Affine map for launch offset: arg5 * lqp * dk + affine_map_launch_offset = AffineMap.get( + 0, + 1, + [ + AffineExpr.get_mul( + AffineSymbolExpr.get(0), AffineConstantExpr.get(lqp * dk) + ) + ], + ) + # Affine map for Q head offset: head * lq * dk + launch_offset + affine_map_q_head_offset = AffineMap.get( + 0, + 2, + [ + AffineExpr.get_add( + AffineExpr.get_mul( + AffineSymbolExpr.get(0), AffineConstantExpr.get(lq * dk) + ), + AffineSymbolExpr.get(1), + ) + ], + ) + # Affine map for K head offset: head * lk * dk + row_offset * dk + # K stored as [num_kv_heads, lk, dk] (row-major) + affine_map_head_row = AffineMap.get( + 0, + 2, + [ + AffineExpr.get_add( + 
AffineExpr.get_mul( + AffineSymbolExpr.get(0), AffineConstantExpr.get(lk * dk) + ), + AffineExpr.get_mul( + AffineSymbolExpr.get(1), AffineConstantExpr.get(dk) + ), + ) + ], + ) + # Affine map for V head offset: head * lk * dv + affine_map_v_head_offset = AffineMap.get( + 0, + 1, + [ + AffineExpr.get_mul( + AffineSymbolExpr.get(0), AffineConstantExpr.get(lk * dv) + ) + ], + ) + + # Combined Q/K/V/output DMA loop — one iteration per q_iter + # Must be a single loop so Q, K, V, and output are interleaved in + # the correct order matching the segment's consumption pattern. + c_device_lq_iters = ConstantOp(index_type, device_lq_iters) + for lq_it in range_(c0, c_device_lq_iters, c1): + # Combine launch Q index (arg5) + device Q index (lq_it) + # Non-causal: arg5 varies, lq_it=0. Causal: arg5=0, lq_it varies. + q_iter_global = arith.AddIOp(arg5, lq_it) + + # (A) Q: L3→L2 for this q_iter + par_1 = scf.ForallOp( + lower_bounds=[0], upper_bounds=[num_cascade_stages], steps=[1] + ) + with InsertionPoint(par_1.body): + tile_offset = affine_apply( + affine_map_tileq, [par_1.induction_variables[0]] + ) + launch_offset = affine_apply( + affine_map_launch_offset, [q_iter_global.result] + ) + # Head 0 in group (head_base) + q_head0_off = affine_apply( + affine_map_q_head_offset, [head_base, launch_offset] + ) + ChannelPut( + "L3ToL2Chan1", + arg9, + indices=[c0, par_1.induction_variables[0]], + offsets=[tile_offset, q_head0_off], + sizes=[tile_size_q, dk], + strides=[dk, 1], + ) + scf.InParallelOp() + + # (B) K: L3→L2 for this q_iter (same K data re-sent each iter) + for i in range(num_cascade_stages): + row_off = ConstantOp(index_type, i * chunks_per_stage * lkp) + k_head0_off = affine_apply( + affine_map_head_row, [kv_head_base, row_off] + ) + ChannelPut( + "L3ToL2Chan1", + arg10, + indices=[c0, i], + offsets=[0, 0, k_head0_off], + sizes=[chunks_per_stage, lkp, dk], + strides=[lkp * dk, dk, 1], + ) + + # (C) V: L3→L2 for this q_iter (same V data re-sent each iter) + for i in 
range(num_cascade_stages): + v_head0_off = affine_apply(affine_map_v_head_offset, [kv_head_base]) + ChannelPut( + "L3ToL2Chan2", + arg11, + indices=[c0, i], + offsets=[0, i * chunks_per_stage * lkp, v_head0_off], + sizes=[chunks_per_stage, lkp, dv], + strides=[lkp * dv, dv, 1], + ) + + # (D) Output: L2→L3 for this q_iter + launch_offset_out = affine_apply( + affine_map_launch_offset, [q_iter_global.result] + ) + out_head0_off = affine_apply( + affine_map_q_head_offset, [head_base, launch_offset_out] + ) + ChannelGet( + "L2ToL3Chan1", + arg12, + indices=[c0], + offsets=[0, out_head0_off], + sizes=[lqp, dk], + strides=[dk, 1], + ) + + yield_([]) + + # Segment unrolls over 1 head (NPU1: 4x4 array fits 1 head) + c_num_heads_unroll = ConstantOp(index_type, num_heads_per_unroll) + c_dummy_size = ConstantOp(index_type, 1) + + # In causal mode, pass launch Q index through segment to herd + # for causal block index computation. After runtime loop tiling + # (runtime_loop_tiling_sizes=[1,1]), arg5 becomes a constant in + # each tiled iteration, so the RTP write in airrt-to-npu succeeds. 
+ seg_operands = [] + + @segment( + name="attention_seg", + operands=seg_operands, + sizes=[c_num_heads_unroll, c_dummy_size], + ) + def segment_body(*seg_args): + head_idx, dummy_idx, head_size, dummy_size = seg_args[:4] + launch_q_idx = seg_args[4] if (causal and len(seg_args) > 4) else None + # L2 allocations + if enable_shared_buffers: + alloc = alloc_col1 = alloc_col2 = alloc_col3 = None + else: + alloc = AllocOp(memref_lqp_dk_l2, [], []) + alloc_col1 = AllocOp(memref_lqp_dk_l2, [], []) + alloc_col2 = AllocOp(memref_lqp_dk_l2, [], []) + alloc_col3 = AllocOp(memref_lqp_dk_l2, [], []) + alloc_2 = AllocOp(memref_dk_lkp_l2, [], []) + alloc_21 = AllocOp(memref_dk_lkp_l2, [], []) + alloc_22 = AllocOp(memref_dk_lkp_l2, [], []) + alloc_23 = AllocOp(memref_dk_lkp_l2, [], []) + alloc_3 = AllocOp(memref_lkp_dv_l2, [], []) + alloc_31 = AllocOp(memref_lkp_dv_l2, [], []) + alloc_32 = AllocOp(memref_lkp_dv_l2, [], []) + alloc_33 = AllocOp(memref_lkp_dv_l2, [], []) + alloc_5 = AllocOp(memref_output_lqp_dv_l2, [], []) + up = AllocOp(memref_lqp_l1, [], []) + sp = AllocOp(memref_lqp_l1, [], []) + Gp = AllocOp(memref_lqp_dv_l1, [], []) + alloc_6 = AllocOp(memref_lqp_dv_l1, [], []) + if enable_shared_buffers: + G_shared = AllocOp(memref_g_shared_l1, [], []) + QK_shared = AllocOp(memref_dv_lkp_l1, [], []) + else: + G_shared = None + QK_shared = None + # Local counter for causal block index tracking. + # Passed as memref operand (NOT scalar) → no RTP, no herd lock. 
+ causal_counter = AllocOp(memref_2xi32_l1, [], []) if causal else None + + c_num_q_tiles = ConstantOp(index_type, num_q_tiles) + c_num_cascade = ConstantOp(index_type, num_cascade_stages) + c0_seg = ConstantOp(index_type, 0) + c1_seg = ConstantOp(index_type, 1) + c2_seg = ConstantOp(index_type, 2) + c3_seg = ConstantOp(index_type, 3) + + # Q/K/V/output DMA loop over lq_iters (Q iteration moved from launch to device) + q_l2_bufs = ( + [alloc_2, alloc_21, alloc_22, alloc_23] + if enable_shared_buffers + else [alloc, alloc_col1, alloc_col2, alloc_col3] + ) + q_chan = "L2ToL1Chan2" if enable_shared_buffers else "L2ToL1Chan1" + q_idx = lambda col: ( + [c0_seg, col] if enable_shared_buffers else [col, c0_seg] + ) + + c_device_lq_seg = ConstantOp(index_type, device_lq_iters) + for lq_it_seg in range_(c0_seg, c_device_lq_seg, c1_seg): + # (A) Q: L3→L2 gets for this q_iter's 4 tiles + ChannelGet( + "L3ToL2Chan1", q_l2_bufs[0].result, indices=[head_idx, c0_seg] + ) + ChannelGet( + "L3ToL2Chan1", q_l2_bufs[1].result, indices=[head_idx, c1_seg] + ) + ChannelGet( + "L3ToL2Chan1", q_l2_bufs[2].result, indices=[head_idx, c2_seg] + ) + ChannelGet( + "L3ToL2Chan1", q_l2_bufs[3].result, indices=[head_idx, c3_seg] + ) + + # (B) Q: L2→L1 puts for this q_iter's 4 tiles + ChannelPut( + q_chan, + q_l2_bufs[0].result, + indices=q_idx(c0_seg), + offsets=[0, 0, 0, 0], + sizes=[dk // mmul_k, tile_size_q // mmul_m, mmul_m, mmul_k], + strides=[mmul_k, dk * mmul_m, dk, 1], + ) + ChannelPut( + q_chan, + q_l2_bufs[1].result, + indices=q_idx(c1_seg), + offsets=[0, 0, 0, 0], + sizes=[dk // mmul_k, tile_size_q // mmul_m, mmul_m, mmul_k], + strides=[mmul_k, dk * mmul_m, dk, 1], + ) + ChannelPut( + q_chan, + q_l2_bufs[2].result, + indices=q_idx(c2_seg), + offsets=[0, 0, 0, 0], + sizes=[dk // mmul_k, tile_size_q // mmul_m, mmul_m, mmul_k], + strides=[mmul_k, dk * mmul_m, dk, 1], + ) + ChannelPut( + q_chan, + q_l2_bufs[3].result, + indices=q_idx(c3_seg), + offsets=[0, 0, 0, 0], + sizes=[dk // mmul_k, 
tile_size_q // mmul_m, mmul_m, mmul_k], + strides=[mmul_k, dk * mmul_m, dk, 1], + ) + + # (C) K/V streaming: L3→L2 + L2→L1 (inner loop) + for arg21 in range_(0, chunks_per_stage, 1): + # Channel gets for K and V - use head_idx + ChannelGet( + "L3ToL2Chan1", alloc_2.result, indices=[head_idx, c0_seg] + ) + ChannelGet( + "L3ToL2Chan2", alloc_3.result, indices=[head_idx, c0_seg] + ) + ChannelGet( + "L3ToL2Chan1", alloc_21.result, indices=[head_idx, c1_seg] + ) + ChannelGet( + "L3ToL2Chan2", alloc_31.result, indices=[head_idx, c1_seg] + ) + ChannelGet( + "L3ToL2Chan1", alloc_22.result, indices=[head_idx, c2_seg] + ) + ChannelGet( + "L3ToL2Chan2", alloc_32.result, indices=[head_idx, c2_seg] + ) + ChannelGet( + "L3ToL2Chan1", alloc_23.result, indices=[head_idx, c3_seg] + ) + ChannelGet( + "L3ToL2Chan2", alloc_33.result, indices=[head_idx, c3_seg] + ) + + # Channel puts for K matrix to L1 + ChannelPut( + "L2ToL1Chan2", + alloc_2.result, + indices=[c0_seg, c0_seg], + offsets=[0, 0, 0, 0], + sizes=[dk // mmul_k, lkp // mmul_n, mmul_n, mmul_k], + strides=[mmul_k, dk * mmul_n, dk, 1], + ) + ChannelPut( + "L2ToL1Chan2", + alloc_21.result, + indices=[c0_seg, c1_seg], + offsets=[0, 0, 0, 0], + sizes=[dk // mmul_k, lkp // mmul_n, mmul_n, mmul_k], + strides=[mmul_k, dk * mmul_n, dk, 1], + ) + ChannelPut( + "L2ToL1Chan2", + alloc_22.result, + indices=[c0_seg, c2_seg], + offsets=[0, 0, 0, 0], + sizes=[dk // mmul_k, lkp // mmul_n, mmul_n, mmul_k], + strides=[mmul_k, dk * mmul_n, dk, 1], + ) + ChannelPut( + "L2ToL1Chan2", + alloc_23.result, + indices=[c0_seg, c3_seg], + offsets=[0, 0, 0, 0], + sizes=[dk // mmul_k, lkp // mmul_n, mmul_n, mmul_k], + strides=[mmul_k, dk * mmul_n, dk, 1], + ) + + # Channel puts for V matrix to L1 + ChannelPut( + "L2ToL1Chan3", + alloc_3.result, + indices=[c0_seg, c0_seg], + offsets=[0, 0, 0, 0], + sizes=[dv // mmul_n, lkp // mmul_k, mmul_k, mmul_n], + strides=[mmul_n, dv * mmul_k, dv, 1], + ) + ChannelPut( + "L2ToL1Chan3", + alloc_31.result, + 
indices=[c0_seg, c1_seg], + offsets=[0, 0, 0, 0], + sizes=[dv // mmul_n, lkp // mmul_k, mmul_k, mmul_n], + strides=[mmul_n, dv * mmul_k, dv, 1], + ) + ChannelPut( + "L2ToL1Chan3", + alloc_32.result, + indices=[c0_seg, c2_seg], + offsets=[0, 0, 0, 0], + sizes=[dv // mmul_n, lkp // mmul_k, mmul_k, mmul_n], + strides=[mmul_n, dv * mmul_k, dv, 1], + ) + ChannelPut( + "L2ToL1Chan3", + alloc_33.result, + indices=[c0_seg, c3_seg], + offsets=[0, 0, 0, 0], + sizes=[dv // mmul_n, lkp // mmul_k, mmul_k, mmul_n], + strides=[mmul_n, dv * mmul_k, dv, 1], + ) + + yield_([]) + + # (D) Output: L1→L2 gather for this q_iter + affine_map_tileq_seg = AffineMap.get( + 0, + 1, + [ + AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(tile_size_q), + ) + ], + ) + par_final = scf.ForallOp( + lower_bounds=[0], upper_bounds=[c_num_q_tiles], steps=[1] + ) + with InsertionPoint(par_final.body): + apply_final = affine_apply( + affine_map_tileq_seg, [par_final.induction_variables[0]] + ) + ChannelGet( + "L1ToL2Chan1", + alloc_5.result, + indices=[par_final.induction_variables[0], 0], + offsets=[apply_final, 0], + sizes=[tile_size_q, dv], + strides=[dv, 1], + ) + scf.InParallelOp() + + # (E) Output: L2→L3 transfer for this q_iter + ChannelPut("L2ToL3Chan1", alloc_5.result, indices=[head_idx]) + + yield_([]) + + # Unified herd: init + compute loop + cascade merge + output + unified_operands = ( + [alloc_6, up, sp, Gp, G_shared, QK_shared] + if enable_shared_buffers + else [alloc_6, up, sp, Gp] + ) + # Causal: pass counter as memref operand (no RTP/lock) + if causal: + unified_operands = unified_operands + [causal_counter] + + @herd( + name="herd_0", + sizes=[c_num_q_tiles, c_num_cascade], + operands=unified_operands, + link_with="attn_npu1.o", + ) + def unified_herd_body(*args): + arg22, arg23, arg24, arg25 = args[0], args[1], args[2], args[3] + if enable_shared_buffers: + arg26, arg27, arg28, arg29, arg30, arg31 = args[4:10] + counter_buf = args[10] if causal else None + else: 
+ arg26, arg27, arg28, arg29 = args[4:8] + arg30 = arg31 = None + counter_buf = args[8] if causal else None + + if causal: + # Local counter. With lkp==dk (shared + # buffers), DMAs are infinite loops → no PDI reset + # → core loops continuously → counter persists. + # counter[0] = q_block_global + # counter[1] = boot flag (0=first, 1=initialized) + c0_ctr = ConstantOp(index_type, 0) + c1_ctr = ConstantOp(index_type, 1) + boot_flag = load(counter_buf, [c1_ctr]) + c0_i32_ctr = ConstantOp(i32, 0) + is_first = arith.CmpIOp( + arith.CmpIPredicate.eq, boot_flag, c0_i32_ctr + ) + if_first = scf.IfOp(is_first) + with InsertionPoint(if_first.then_block): + q_init = arith.IndexCastOp(i32, arg22) + store(q_init, counter_buf, [c0_ctr]) + c1_i32_f = ConstantOp(i32, 1) + store(c1_i32_f, counter_buf, [c1_ctr]) + scf.YieldOp([]) + + # === OUTER Q ITERATION LOOP (device-side) === + c_lq_iters_herd = ConstantOp(index_type, device_lq_iters) + c0_q = ConstantOp(index_type, 0) + c1_q = ConstantOp(index_type, 1) + + for q_iter in range_(c0_q, c_lq_iters_herd, c1_q): + + # === INIT PHASE === + if enable_shared_buffers: + ChannelGet("L2ToL1Chan2", arg31, indices=[arg22, arg23]) + CallOp([], "copy_tile", [arg31, arg26]) + else: + ChannelGet("L2ToL1Chan1", arg26, indices=[arg22, arg23]) + CallOp([], "zero_fill_gp_bf16", [arg29]) + CallOp([], "zero_fill_sp_bf16", [arg28]) + CallOp([], "neg_inf_fill_up_bf16", [arg27]) + + # === COMPUTE LOOP (on-device) === + c_chunks = ConstantOp(index_type, chunks_per_stage) + c0_loop = ConstantOp(index_type, 0) + c1_loop = ConstantOp(index_type, 1) + + for chunk_idx in range_(c0_loop, c_chunks, c1_loop): + if enable_shared_buffers: + G_l1 = CollapseShapeOp( + memref_lqp_lkp_l1, arg30, [[0, 1]] + ) + else: + G_alloc = AllocOp(memref_g_shared_l1, [], []) + G_l1 = CollapseShapeOp( + memref_lqp_lkp_l1, G_alloc.result, [[0, 1]] + ) + + CallOp([], "zero_fill_g_bf16", [G_l1]) + + if enable_shared_buffers: + ChannelGet("L2ToL1Chan2", arg31, indices=[arg22, arg23]) 
+ CallOp([], "matmul_a_b_bf16", [arg26, arg31, G_l1]) + else: + QK_alloc = AllocOp(memref_dv_lkp_l1, [], []) + ChannelGet( + "L2ToL1Chan2", + QK_alloc.result, + indices=[arg22, arg23], + ) + CallOp( + [], + "matmul_a_b_bf16", + [arg26, QK_alloc.result, G_l1], + ) + + alloc_57 = AllocOp(memref_dv_lkp_l1, [], []) + ChannelGet( + "L2ToL1Chan3", alloc_57.result, indices=[arg22, arg23] + ) + + if causal: + # Local counter gives q_block_global. + # No RTP/herd lock — counter loaded from + # local L1 buffer. + c_cps = ConstantOp(index_type, chunks_per_stage) + kv_block = arith.AddIOp( + arith.MulIOp(arg23, c_cps).result, chunk_idx + ) + kv_i32 = arith.IndexCastOp(i32, kv_block.result) + c0_ctr_use = ConstantOp(index_type, 0) + q_i32 = load(counter_buf, [c0_ctr_use]) + CallOp([], "apply_causal_mask", [G_l1, q_i32, kv_i32]) + + c0_i32 = ConstantOp(i32, 0) + s_l1 = AllocOp(memref_lqp_l1, [], []) + r_l1 = AllocOp(memref_lqp_l1, [], []) + + # True fused softmax: max+exp+sum with f32 intermediates + CallOp( + [], + "fused_softmax", + [G_l1, arg27, s_l1.result, r_l1.result], + ) + CallOp([], "mul_r_gp", [r_l1.result, arg29]) + CallOp( + [], + "matmul_g_b_bf16", + [G_l1, alloc_57.result, arg29], + ) + CallOp( + [], + "accum_sp_r_s", + [arg28, r_l1.result, s_l1.result], + ) + CallOp( + [], + "vector_copy_32elems", + [c0_i32, s_l1.result, arg28], + ) + + DeallocOp(s_l1) + DeallocOp(r_l1) + else: + c0_i32 = ConstantOp(i32, 0) + s_l1 = AllocOp(memref_lqp_l1, [], []) + r_l1 = AllocOp(memref_lqp_l1, [], []) + + # True fused softmax: max+exp+sum with f32 intermediates + CallOp( + [], + "fused_softmax", + [G_l1, arg27, s_l1.result, r_l1.result], + ) + CallOp([], "mul_r_gp", [r_l1.result, arg29]) + CallOp( + [], + "matmul_g_b_bf16", + [G_l1, alloc_57.result, arg29], + ) + CallOp( + [], + "accum_sp_r_s", + [arg28, r_l1.result, s_l1.result], + ) + CallOp( + [], + "vector_copy_32elems", + [c0_i32, s_l1.result, arg28], + ) + + DeallocOp(s_l1) + DeallocOp(r_l1) + + DeallocOp(alloc_57) + + if 
not enable_shared_buffers: + DeallocOp(QK_alloc) + DeallocOp(G_alloc) + yield_([]) + + # === CASCADE MERGE === + c1_h = ConstantOp(index_type, 1) + r_l1_c = AllocOp(memref_lqp_l1, [], []) + + def get_gp_cascade(): + if enable_shared_buffers: + return arg30 + else: + return AllocOp(memref_lqp_dv_l1, [], []).result + + # affine.if for last cascade stage + affine_set_last = IntegerSet.get( + 0, + 2, + [ + AffineExpr.get_add( + AffineSymbolExpr.get(1), + AffineConstantExpr.get(-num_cascade_stages + 1), + ), + AffineSymbolExpr.get(0), + AffineExpr.get_add( + AffineConstantExpr.get(num_q_tiles - 1), + AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(-1), + ), + ), + ], + [True, False, False], + ) + affine_if_last = affine.AffineIfOp( + affine_set_last, cond_operands=[arg22, arg23], has_else=True + ) + with InsertionPoint(affine_if_last.then_block): + subi = arith.SubIOp(arg23, c1_h) + ChannelPut("cascade", arg29, indices=[arg22, subi]) + ChannelPut("cascade", arg27, indices=[arg22, subi]) + ChannelPut("cascade", arg28, indices=[arg22, subi]) + affine.AffineYieldOp([]) + + with InsertionPoint(affine_if_last.else_block): + affine_set_middle = IntegerSet.get( + 0, + 2, + [ + AffineExpr.get_add( + AffineSymbolExpr.get(1), + AffineConstantExpr.get(-1), + ), + AffineExpr.get_add( + AffineConstantExpr.get(num_cascade_stages - 2), + AffineExpr.get_mul( + AffineSymbolExpr.get(1), + AffineConstantExpr.get(-1), + ), + ), + AffineSymbolExpr.get(0), + AffineExpr.get_add( + AffineConstantExpr.get(num_q_tiles - 1), + AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(-1), + ), + ), + ], + [False, False, False, False], + ) + affine_if_middle = affine.AffineIfOp( + affine_set_middle, + cond_operands=[arg22, arg23], + has_else=True, + ) + with InsertionPoint(affine_if_middle.then_block): + Gp_cascade = get_gp_cascade() + up_cascade = AllocOp(memref_lqp_l1, [], []) + sp_cascade = AllocOp(memref_lqp_l1, [], []) + ChannelGet( + "cascade", 
Gp_cascade, indices=[arg22, arg23] + ) + ChannelGet( + "cascade", up_cascade.result, indices=[arg22, arg23] + ) + ChannelGet( + "cascade", sp_cascade.result, indices=[arg22, arg23] + ) + up_B_saved = AllocOp(memref_lqp_l1, [], []) + c0_i32_m = ConstantOp(i32, 0) + CallOp( + [], + "vector_copy_32elems", + [c0_i32_m, arg27, up_B_saved.result], + ) + CallOp( + [], "maximum_up_u_bf16", [up_cascade.result, arg27] + ) + CallOp( + [], + "exp_up_minus_u", + [up_cascade.result, arg27, r_l1_c.result], + ) + r_B = AllocOp(memref_lqp_l1, [], []) + CallOp( + [], + "exp_up_minus_u", + [up_B_saved.result, arg27, r_B.result], + ) + CallOp([], "mul_r_gp", [r_l1_c.result, Gp_cascade]) + CallOp([], "mul_r_gp", [r_B.result, arg29]) + CallOp([], "add_gp_g", [arg29, Gp_cascade]) + sp_temp = AllocOp(memref_lqp_l1, [], []) + CallOp([], "zero_fill_sp_bf16", [sp_temp.result]) + CallOp( + [], + "accum_sp_r_s", + [sp_cascade.result, r_l1_c.result, sp_temp.result], + ) + CallOp( + [], + "accum_sp_r_s", + [arg28, r_B.result, sp_temp.result], + ) + CallOp( + [], + "vector_copy_32elems", + [c0_i32_m, sp_temp.result, sp_cascade.result], + ) + subi2 = arith.SubIOp(arg23, c1_h) + ChannelPut( + "cascade", Gp_cascade, indices=[arg22, subi2] + ) + ChannelPut("cascade", arg27, indices=[arg22, subi2]) + ChannelPut( + "cascade", sp_cascade.result, indices=[arg22, subi2] + ) + DeallocOp(up_B_saved) + DeallocOp(r_B) + DeallocOp(sp_temp) + affine.AffineYieldOp([]) + + with InsertionPoint(affine_if_middle.else_block): + Gp_cascade2 = get_gp_cascade() + up_cascade2 = AllocOp(memref_lqp_l1, [], []) + sp_cascade2 = AllocOp(memref_lqp_l1, [], []) + ChannelGet( + "cascade", Gp_cascade2, indices=[arg22, arg23] + ) + ChannelGet( + "cascade", + up_cascade2.result, + indices=[arg22, arg23], + ) + ChannelGet( + "cascade", + sp_cascade2.result, + indices=[arg22, arg23], + ) + up_B_saved2 = AllocOp(memref_lqp_l1, [], []) + c0_i32_f = ConstantOp(i32, 0) + CallOp( + [], + "vector_copy_32elems", + [c0_i32_f, arg27, 
up_B_saved2.result], + ) + CallOp( + [], "maximum_up_u_bf16", [up_cascade2.result, arg27] + ) + CallOp( + [], + "exp_up_minus_u", + [up_cascade2.result, arg27, r_l1_c.result], + ) + r_B2 = AllocOp(memref_lqp_l1, [], []) + CallOp( + [], + "exp_up_minus_u", + [up_B_saved2.result, arg27, r_B2.result], + ) + CallOp([], "mul_r_gp", [r_l1_c.result, Gp_cascade2]) + CallOp([], "mul_r_gp", [r_B2.result, arg29]) + CallOp([], "add_gp_g", [arg29, Gp_cascade2]) + sp_temp2 = AllocOp(memref_lqp_l1, [], []) + CallOp([], "zero_fill_sp_bf16", [sp_temp2.result]) + CallOp( + [], + "accum_sp_r_s", + [ + sp_cascade2.result, + r_l1_c.result, + sp_temp2.result, + ], + ) + CallOp( + [], + "accum_sp_r_s", + [arg28, r_B2.result, sp_temp2.result], + ) + CallOp( + [], + "vector_copy_32elems", + [c0_i32_f, sp_temp2.result, sp_cascade2.result], + ) + CallOp( + [], "div_gp_sp", [sp_cascade2.result, Gp_cascade2] + ) + DeallocOp(up_B_saved2) + DeallocOp(r_B2) + DeallocOp(sp_temp2) + ChannelPut( + "L1ToL2Chan1", + Gp_cascade2, + indices=[arg22, 0], + offsets=[0, 0, 0, 0], + sizes=[ + tile_size_q // mmul_n, + mmul_m, + dv // mmul_m, + mmul_n, + ], + strides=[ + mmul_m * mmul_n, + mmul_n, + tile_size_q * mmul_n, + 1, + ], + ) + affine.AffineYieldOp([]) + affine.AffineYieldOp([]) + + # Increment q_block counter for next launch iteration + if causal: + c0_ci = ConstantOp(index_type, 0) + c2_ci = ConstantOp(index_type, 2) + c1_i32_ci = ConstantOp(i32, 1) + # Increment head counter (kept in slot 2 of counter_buf). + # NOTE(review): the buffer type is named memref_2xi32_l1 and the boot + # path earlier in this herd only stores to slots 0 and 1 — confirm the + # buffer really has a third element and that slot 2 starts at zero. + head_cur = load(counter_buf, [c2_ci]) + head_next = arith.AddIOp(head_cur, c1_i32_ci) + total_heads_i32 = ConstantOp(i32, num_head_groups) + wrapped = arith.CmpIOp( + arith.CmpIPredicate.sge, + head_next, + total_heads_i32, + ) + if_wrap = scf.IfOp(wrapped) + with InsertionPoint(if_wrap.then_block): + # All heads done: increment q_block, reset head + q_cur = load(counter_buf, [c0_ci]) + c_nqt_i32 = ConstantOp(i32, num_q_tiles) + q_next = arith.AddIOp(q_cur, c_nqt_i32) + store(q_next, counter_buf, [c0_ci]) + c0_i32_ci = 
ConstantOp(i32, 0) + store(c0_i32_ci, counter_buf, [c2_ci]) + scf.YieldOp([]) + if_wrap_else = scf.IfOp( + arith.CmpIOp( + arith.CmpIPredicate.slt, + head_next, + total_heads_i32, + ) + ) + with InsertionPoint(if_wrap_else.then_block): + store(head_next, counter_buf, [c2_ci]) + scf.YieldOp([]) + + yield_([]) # end of q_iter loop + + # Output channel gets are inside the combined Q/K/V/output loop above + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(prog="attn.py") + parser.add_argument("-p", "--print-module-only", action="store_true") + parser.add_argument("-v", "--verbose", action="store_true") + parser.add_argument( + "--mlir-file", + type=str, + default=None, + help="Path to external MLIR file to compile (instead of generating)", + ) + parser.add_argument( + "--lk", type=int, default=12288, help="Total sequence length for K/V matrices" + ) + parser.add_argument( + "--lkp", type=int, default=96, help="Chunk size for K/V processing" + ) + parser.add_argument( + "--lq", type=int, default=512, help="Total sequence length for Q matrix" + ) + parser.add_argument( + "--lqp", + type=int, + default=128, + help="Chunk size for Q processing per launch iteration", + ) + parser.add_argument("--dk", type=int, default=64, help="Key dimension") + parser.add_argument("--dv", type=int, default=64, help="Value dimension") + parser.add_argument( + "--num-heads", type=int, default=12, help="Number of Q attention heads" + ) + parser.add_argument( + "--num-kv-heads", + type=int, + default=None, + help="Number of K/V heads (default: num_heads for MHA, set < num_heads for GQA)", + ) + parser.add_argument( + "--compile-mode", + type=str, + default="run", + choices=["run", "compile"], + help="Compilation mode: run (default, compile + test), compile (generate binary only)", + ) + parser.add_argument( + "--causal", + action="store_true", + help="Enable causal masking (autoregressive attention)", + ) + parser.add_argument( + "--val-range", + type=float, + default=3.0, + 
help="Input value range for random test data (default: 3.0)", + ) + args = parser.parse_args() + + lk, lkp, lq, lqp, dk, dv = args.lk, args.lkp, args.lq, args.lqp, args.dk, args.dv + causal = args.causal + num_heads = args.num_heads + num_kv_heads = args.num_kv_heads if args.num_kv_heads is not None else num_heads + + if num_kv_heads <= 0: + raise ValueError(f"num_kv_heads must be positive, got {num_kv_heads}") + if num_heads % num_kv_heads != 0: + raise ValueError( + f"num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})" + ) + + if args.mlir_file: + with open(args.mlir_file, "r") as f: + mlir_source = f.read() + with Context() as ctx, Location.unknown(): + registry = DialectRegistry() + air.dialects.air.register_dialect(registry) + ctx.append_dialect_registry(registry) + ctx.load_all_available_dialects() + mlir_module = Module.parse(mlir_source) + print(f"Loaded MLIR module from: {args.mlir_file}") + else: + mlir_module = build_module( + lk=lk, + lkp=lkp, + lq=lq, + lqp=lqp, + dk=dk, + dv=dv, + num_q_tiles=4, + num_cascade_stages=4, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + causal=causal, + ) + + if args.print_module_only: + print(mlir_module) + exit(0) + + from air.backend.xrt_runner import XRTRunner, type_mapper + from air.backend.xrt import XRTBackend + from air.extras import types as extrasT + from ml_dtypes import bfloat16 + + INPUT_DATATYPE = OUTPUT_DATATYPE = bfloat16 + VM_ACC_DATATYPE = np.float32 + + gqa_group_size = num_heads // num_kv_heads + + rng = np.random.default_rng(42) + val_range = args.val_range + input_q = rng.uniform(0, val_range, (num_heads, lq, dk)).astype(INPUT_DATATYPE) + input_k = rng.uniform(0, val_range, (num_kv_heads, lk, dk)).astype(INPUT_DATATYPE) + input_v = rng.uniform(0, val_range, (num_kv_heads, lk, dv)).astype(INPUT_DATATYPE) + input_m = np.zeros((num_heads, lq, lk), dtype=INPUT_DATATYPE) + + inv_sqrt_dk = 1.0 / sqrt(dk) + + def sdpa_golden(Q, K, V, scale, causal_mask=False): + """Standard 
scaled dot-product attention in f32.""" + scores = (Q.astype(np.float32) @ K.astype(np.float32).T) * scale + if causal_mask: + mask = np.triu(np.ones(scores.shape, dtype=bool), k=1) + scores = np.where(mask, -1e9, scores) + m = np.max(scores, axis=-1, keepdims=True) + exp_s = np.exp(scores - m) + P = exp_s / np.sum(exp_s, axis=-1, keepdims=True) + return (P @ V.astype(np.float32)).astype(OUTPUT_DATATYPE) + + sdpa_output = np.zeros((num_heads, lq, dv), dtype=OUTPUT_DATATYPE) + for h in range(num_heads): + kv_h = h // gqa_group_size + sdpa_output[h] = sdpa_golden( + input_q[h], + input_k[kv_h], + input_v[kv_h], + inv_sqrt_dk, + causal_mask=causal, + ) + + enable_shared_buffers_main = lkp == dk + # Causal mode requires while-true loop: the herd RTP mechanism needs the + # core to loop back and re-acquire the herd lock for each launch iteration. + # Without the loop, the core exits after one iteration and subsequent + # RTP writes / lock releases go to a dead core. + omit_loop = False if causal else not enable_shared_buffers_main + runner = XRTRunner( + omit_while_true_loop=omit_loop, + omit_pingpong="all", + verbose=args.verbose, + runtime_loop_tiling_sizes=[1, 1], + output_format="xclbin", + instance_name="attention_bf16", + target_device="npu1", + ) + + if args.compile_mode == "run": + exit( + runner.run_test( + mlir_module, + inputs=[input_q, input_k, input_v, input_m], + expected_outputs=[sdpa_output], + atol=0.15, + rtol=0.04, + max_mismatch_percentage=2, + ) + ) + elif args.compile_mode == "compile": + backend = XRTBackend( + omit_while_true_loop=omit_loop, + omit_pingpong="all", + verbose=args.verbose, + runtime_loop_tiling_sizes=[1, 1], + output_format="xclbin", + instance_name="attention_bf16", + ) + module_function = backend.compile(mlir_module) + print("Compilation complete. 
Generated xclbin binary") diff --git a/programming_examples/flash_attention/packet_switched/attn_pkt.cc b/programming_examples/flash_attention/packet_switched/attn_pkt.cc new file mode 100644 index 000000000..5d3555fd8 --- /dev/null +++ b/programming_examples/flash_attention/packet_switched/attn_pkt.cc @@ -0,0 +1,680 @@ +//===- attn.cc --------------------------------------------------*- C++ -*-===// +// +// SPDX-License-Identifier: MIT +// Copyright (C) 2025, Advanced Micro Devices, Inc. +// +//===----------------------------------------------------------------------===// + +#define NOCPP + +#include +#include +#include +#include + +#define REL_WRITE 0 +#define REL_READ 1 + +#include + +#include "zero.cc" + +// Default values if not provided by Makefile +#ifndef lqp +#define lqp 32 +#endif + +#ifndef lkp +#define lkp 96 +#endif + +#ifndef dk +#define dk 64 +#endif + +#ifndef dv +#define dv 64 +#endif + +// Column-major B matmul with compile-time transpose control. +// transpose_b: true = apply aie::transpose before mac (K DMA: inner [n_in, +// k_in]) +// false = load B as-is, hardware mul_8x8_8x8T transposes (V DMA: +// inner [k_in, n_in]) +// A and C are always column-major tiled. 
+template +static inline void matmul_vectorized_2x2_mmul(const T_in *__restrict pA, + const T_in *__restrict pB, + T_out *__restrict pC) { + + using MMUL = aie::mmul; + + event0(); + + for (unsigned z = 0; z < rowA; z += 2) + chess_prepare_for_pipelining chess_loop_range(2, ) { + T_out *__restrict pC1 = pC + (z)*MMUL::size_C; + T_out *__restrict pC2 = pC + ((z + 1)) * MMUL::size_C; + + for (unsigned j = 0; j < colB; j += 2) +#ifdef OPT_PERF_ENABLED + chess_flatten_loop +#endif + { + const T_in *__restrict pA1 = pA + (z)*MMUL::size_A; + const T_in *__restrict pA2 = pA + ((z + 1)) * MMUL::size_A; + const T_in *__restrict pB1 = pB + (j)*colA * MMUL::size_B; + const T_in *__restrict pB2 = pB + (j + 1) * colA * MMUL::size_B; + + aie::vector acc_C00 = + aie::load_v(pC1); + aie::vector acc_C01 = + aie::load_v(pC1 + MMUL::size_C * rowA); + aie::vector acc_C10 = + aie::load_v(pC2); + aie::vector acc_C11 = + aie::load_v(pC2 + MMUL::size_C * rowA); + + MMUL C00(acc_C00); + MMUL C01(acc_C01); + MMUL C10(acc_C10); + MMUL C11(acc_C11); + + for (unsigned i = 0; i < colA; ++i) +#ifdef OPT_PERF_ENABLED + chess_flatten_loop +#endif + { + aie::vector A0 = + aie::load_v(pA1); + pA1 += rowA * MMUL::size_A; + aie::vector A1 = + aie::load_v(pA2); + pA2 += rowA * MMUL::size_A; + + aie::vector B0, B1; + if constexpr (transpose_b) { + // K DMA inner layout is [n_in, k_in] — need software transpose + // to [k_in, n_in] before hardware mul_8x8_8x8T. + B0 = aie::transpose(aie::load_v(pB1), t, s); + B1 = aie::transpose(aie::load_v(pB2), t, s); + } else { + // V DMA inner layout is [k_in, n_in] — already correct for + // hardware mul_8x8_8x8T, no software transpose needed. 
+ B0 = aie::load_v(pB1); + B1 = aie::load_v(pB2); + } + pB1 += MMUL::size_B; + pB2 += MMUL::size_B; + + C00.mac(A0, B0); + C01.mac(A0, B1); + C10.mac(A1, B0); + C11.mac(A1, B1); + } + + aie::store_v(pC1, C00.template to_vector()); + pC1 += MMUL::size_C * rowA; + aie::store_v(pC1, C01.template to_vector()); + pC1 += MMUL::size_C * rowA; + aie::store_v(pC2, C10.template to_vector()); + pC2 += MMUL::size_C * rowA; + aie::store_v(pC2, C11.template to_vector()); + pC2 += MMUL::size_C * rowA; + } + } + + event1(); +} + +// bf16 MatMul kernel with bf16 outputs. +// transpose_b: controls whether B blocks are software-transposed before mac. +template +static inline void +matmul_vectorized_8x8x8_bf16_bf16(const bfloat16 *__restrict pA, + const bfloat16 *__restrict pB, + bfloat16 *__restrict pC) { + constexpr int r = 8; + constexpr int s = 8; + constexpr int t = 8; + static_assert(m % (2 * r) == 0); // 'm' dimension + static_assert(k % s == 0); // 'k' dimension + static_assert(n % (2 * t) == 0); // 'n' dimension + + return matmul_vectorized_2x2_mmul(pA, pB, pC); +} + +// Combined scale: log2e / sqrt(dk). IRON uses this to apply 1/sqrt(dk) +// inside softmax with accfloat precision, avoiding bf16 truncation of Q. +// dk is a macro from the Makefile (-Ddk=64). +#include +#define log2e (1.44269504089 / constexpr_sqrt_dk) +constexpr double constexpr_sqrt_dk = 8.0; // sqrt(64) — matches dk=64 +// The scale above is hard-coded, so reject any other dk at compile time +// instead of silently applying the wrong softmax scale. +static_assert(dk == 64, + "constexpr_sqrt_dk is hard-coded to sqrt(64); update it for other dk"); + +__attribute__((always_inline)) v8bfloat16 getExpBf16(v8bfloat16 x) { + + constexpr int VecLen = 8; + + // Computes 2^(log2e * x). The log2e macro already folds in 1/sqrt(dk), + // so the result is e^(x / sqrt(dk)) rather than a plain e^x. + aie::vector input_bf16 = x; + aie::accum exp_in; + aie::vector exp_val; + aie::vector log2e_vec = + aie::broadcast(log2e); + + exp_in = aie::mul(input_bf16, log2e_vec); + exp_val = aie::exp2(exp_in.to_vector()); + return exp_val; +} + +extern "C" { + +// Set rounding mode at the start of every extern C function.
+// IRON sets conv_even in every softmax function; without this, +// softmax intermediates use the system default rounding mode, +// causing ~44% errors at val_range=4 due to rounding noise +// amplified by softmax's peaked distribution. +#ifdef ROUND_CONV_EVEN +#define SET_ROUNDING() ::aie::set_rounding(::aie::rounding_mode::conv_even) +#else +#define SET_ROUNDING() /* no-op */ +#endif + +// Copy tile_size_q×dk elements from src to dst (single-pass vector copy) +void copy_tile(bfloat16 *src, bfloat16 *dst) { + SET_ROUNDING(); + constexpr int VecLen = 32; + constexpr int num_elems = lqp * dk; + bfloat16 *__restrict ps = src; + bfloat16 *__restrict pd = dst; + for (unsigned j = 0; j < num_elems / VecLen; j++) + chess_prepare_for_pipelining chess_loop_range(8, ) { + aie::vector v = aie::load_v(ps); + aie::store_v(pd, v); + ps += VecLen; + pd += VecLen; + } +} + +void matmul_a_b_bf16(bfloat16 *a_in, bfloat16 *b_in, bfloat16 *out) { + SET_ROUNDING(); + // Buffer shapes: + // A: [lqp, dk] = [32, 64] + // B: [lkp, dk] = [96, 64] (K row-major, aie::transpose per block) + // Out: [lqp, lkp] = [32, 96] + matmul_vectorized_8x8x8_bf16_bf16(a_in, b_in, out); +} + +void matmul_g_b_bf16(bfloat16 *g_in, bfloat16 *b_in, bfloat16 *out) { + SET_ROUNDING(); + // Buffer shapes: + // G: [lqp, lkp] = [32, 96] + // B: [lkp, dv] = [96, 64] + // Out: [lqp, dv] = [32, 64] + // G@V: V DMA inner layout is [k_in, n_in], so NO software transpose needed. + // The hardware mul_8x8_8x8T already transposes B internally. 
+ matmul_vectorized_8x8x8_bf16_bf16( + g_in, b_in, out); +} + +void zero_fill_gp_bf16(bfloat16 *c_out) { + SET_ROUNDING(); + // Buffer shape: [lqp, dv] = [32, 64] + zero_vectorized(c_out); +} + +void zero_fill_sp_bf16(bfloat16 *c_out) { + SET_ROUNDING(); + // Buffer shape: [lqp, 1] = [32, 1] + zero_vectorized(c_out); +} + +void zero_fill_g_bf16(bfloat16 *c_out) { + SET_ROUNDING(); + // Buffer shape: [lqp, lkp] = [32, 96] + zero_vectorized(c_out); +} + +void neg_inf_fill_up_bf16(bfloat16 *c_out) { + SET_ROUNDING(); + // Buffer shape: [lqp, 1] = [32, 1] + neg_inf_vectorized(c_out); +} + +void max_g_bf16(bfloat16 *in, bfloat16 *out) { + SET_ROUNDING(); + // u = np.max(G, axis=-1, keepdims=True) + // G is in column-major 8x8 tiled layout. + // Each block is 64 contiguous elements (8 rows × 8 cols). + // VecLen=32 reads 4 rows at once (half a block). + constexpr int VecLen = 32; + constexpr int BlockSize = 64; // 8×8 block + constexpr int ColsPerBlock = 8; + constexpr int RowsPerBlock = 8; + constexpr int col_blocks = lkp / ColsPerBlock; + constexpr int row_blocks = lqp / RowsPerBlock; + constexpr int block_stride = + lqp * ColsPerBlock; // stride between column blocks + + // Use bf16 lowest (0xff7f) instead of -inf (0xff80) as initial max value. + // For fully-masked rows (all -inf), max returns bf16_lowest > -inf, + // avoiding NaN in exp(G - u) where G=-inf and u would be -inf. 
+ uint16_t lowest_u16 = (uint16_t)0xff7f; + bfloat16 lowest_val = *(bfloat16 *)&lowest_u16; + + bfloat16 *__restrict pOut = out; + for (int rb = 0; rb < row_blocks; rb++) { + // Process 4 rows at a time (half block = 32 elements) + for (int half = 0; half < 2; half++) { + aie::vector max_vec = + aie::broadcast(lowest_val); + int base = rb * BlockSize + half * VecLen; + for (int cb = 0; cb < col_blocks; cb++) + chess_prepare_for_pipelining chess_loop_range(8, ) { + aie::vector v = + aie::load_v(in + base + cb * block_stride); + max_vec = aie::max(max_vec, v); + } + // Extract per-row max from 32-wide vector (4 rows × 8 cols) + aie::vector r0 = max_vec.extract<8>(0); + aie::vector r1 = max_vec.extract<8>(1); + aie::vector r2 = max_vec.extract<8>(2); + aie::vector r3 = max_vec.extract<8>(3); + pOut[half * 4 + 0] = aie::reduce_max(r0); + pOut[half * 4 + 1] = aie::reduce_max(r1); + pOut[half * 4 + 2] = aie::reduce_max(r2); + pOut[half * 4 + 3] = aie::reduce_max(r3); + } + pOut += RowsPerBlock; + } +} + +void maximum_up_u_bf16(bfloat16 *up, bfloat16 *u) { + SET_ROUNDING(); + // u = np.maximum(u, up) + // Buffer shape: + // up: [lqp, 1] = [32, 1] + // u: [lqp, 1] = [32, 1] + constexpr int VecLen = 32; + constexpr int num_elems = lqp; + bfloat16 *__restrict pu = u; + for (int i = 0; i < num_elems; i += VecLen) { + aie::vector up_temp = aie::load_v(up + i); + aie::vector u_temp = aie::load_v(pu); + u_temp = aie::max(up_temp, u_temp); + aie::store_v(pu, u_temp); + pu += VecLen; + } +} + +void exp_g_minus_u(bfloat16 *u, bfloat16 *g) { + SET_ROUNDING(); + // G = exp(G - u) in-place. G is column-major 8×8 tiled. + // VecLen=32 processes 4 rows at once (half a block). + // exp2 native width is 16, so split 32→2×16 for exp. + // With bf16 lowest (not -inf), lowest - lowest = 0 (not NaN).
+ constexpr int VecLen = 32; + constexpr int BlockSize = 64; + constexpr int ColsPerBlock = 8; + constexpr int RowsPerBlock = 8; + constexpr int col_blocks = lkp / ColsPerBlock; + constexpr int row_blocks = lqp / RowsPerBlock; + constexpr int block_stride = lqp * ColsPerBlock; + + uint16_t lowest_u16 = (uint16_t)0xff7f; + bfloat16 lowest_val = *(bfloat16 *)&lowest_u16; + aie::vector log2e_vec16 = + aie::broadcast((bfloat16)log2e); + aie::vector lowest_vec = + aie::broadcast(lowest_val); + + for (int rb = 0; rb < row_blocks; rb++) { + for (int half = 0; half < 2; half++) { + // Build 32-wide u vector: 4 rows × 8 cols, each row broadcast + int row_start = rb * RowsPerBlock + half * 4; + aie::vector u0 = aie::broadcast(u[row_start]); + aie::vector u1 = + aie::broadcast(u[row_start + 1]); + aie::vector u2 = + aie::broadcast(u[row_start + 2]); + aie::vector u3 = + aie::broadcast(u[row_start + 3]); + aie::vector u_vec; + u_vec.insert(0, u0); + u_vec.insert(1, u1); + u_vec.insert(2, u2); + u_vec.insert(3, u3); + + int base = rb * BlockSize + half * VecLen; + for (int cb = 0; cb < col_blocks; cb++) + chess_prepare_for_pipelining chess_loop_range(8, ) { + int off = base + cb * block_stride; + aie::vector v = aie::load_v(g + off); + v = aie::sub(v, u_vec); + v = aie::max(v, lowest_vec); + // exp2(log2e * v) — split into 2×16 for native exp2 width + aie::vector lo = v.extract<16>(0); + aie::vector hi = v.extract<16>(1); + lo = + aie::exp2(aie::mul(lo, log2e_vec16).to_vector()); + hi = + aie::exp2(aie::mul(hi, log2e_vec16).to_vector()); + v.insert(0, lo); + v.insert(1, hi); + aie::store_v(g + off, v); + } + } + } +} + +void exp_up_minus_u(bfloat16 *up, bfloat16 *u, bfloat16 *r) { + SET_ROUNDING(); + // r[i] = exp(up[i] - u[i]) for i in [0, lqp), computed as + // exp2(log2e * max(up[i] - u[i], bf16 lowest)). + // up, u: inputs, lqp elements each (only read); r: output, lqp elements. + // VecLen=16 to match exp2 native width; the loop steps by 16 elements, so + // this assumes lqp is a multiple of 16. + // With bf16 lowest (not -inf), lowest - lowest = 0 (not NaN).
+ constexpr int VecLen = 16; + constexpr int num_elems = lqp; + uint16_t lowest_u16 = (uint16_t)0xff7f; + bfloat16 lowest_val = *(bfloat16 *)&lowest_u16; + aie::vector lowest_vec = + aie::broadcast(lowest_val); + bfloat16 *__restrict pr = r; + bfloat16 *__restrict pu = u; + bfloat16 *__restrict pup = up; + aie::vector log2e_vec = + aie::broadcast((bfloat16)log2e); + for (int i = 0; i < num_elems; i += VecLen) { + aie::vector uTemp = aie::load_v(pu); + aie::vector upTemp = aie::load_v(pup); + aie::vector diff = aie::sub(upTemp, uTemp); + // Clamp extreme negative values to bf16 lowest (0xff7f) before scaling + diff = aie::max(diff, lowest_vec); + aie::vector exp_val = + aie::exp2(aie::mul(diff, log2e_vec).to_vector()); + aie::store_v(pr, exp_val); + pr += VecLen; + pu += VecLen; + pup += VecLen; + } +} + +void mul_r_gp(bfloat16 *r, bfloat16 *gp) { + SET_ROUNDING(); + // Gp = Gp * r (per-row scaling), updated in place: gp is read and written, + // r is only read. + // Buffer shape: Gp: [lqp, dv], r: [lqp, 1] + // Layout: column-major 8×8 block tiled (same as matmul output). + // block(col_blk, row_blk) at offset col_blk * (lqp * 8) + row_blk * 64, + // element within block at row_in * 8 + col_in. + // VecLen=32 reads 4 rows × 8 cols (half a block).
+ constexpr int VecLen = 32; + constexpr int BlockSize = 64; // 8×8 block + constexpr int ColsPerBlock = 8; + constexpr int RowsPerBlock = 8; + constexpr int col_blocks = dv / ColsPerBlock; + constexpr int row_blocks = lqp / RowsPerBlock; + constexpr int block_stride = + lqp * ColsPerBlock; // stride between column blocks + + for (int rb = 0; rb < row_blocks; rb++) { + for (int half = 0; half < 2; half++) { + // Build 32-wide r vector: 4 rows × 8 cols, each row's r broadcast to 8 + int row_start = rb * RowsPerBlock + half * 4; + aie::vector r0 = aie::broadcast(r[row_start]); + aie::vector r1 = + aie::broadcast(r[row_start + 1]); + aie::vector r2 = + aie::broadcast(r[row_start + 2]); + aie::vector r3 = + aie::broadcast(r[row_start + 3]); + aie::vector r_vec; + r_vec.insert(0, r0); + r_vec.insert(1, r1); + r_vec.insert(2, r2); + r_vec.insert(3, r3); + + // r_vec is built once per 4-row group and reused for every column block. + // NOTE(review): chess_loop_range(8, ) presumably promises the compiler at + // least 8 iterations (i.e. dv >= 64) — confirm before using a smaller dv. + int base = rb * BlockSize + half * VecLen; + for (int cb = 0; cb < col_blocks; cb++) + chess_prepare_for_pipelining chess_loop_range(8, ) { + int off = base + cb * block_stride; + aie::vector v = aie::load_v(gp + off); + aie::accum acc = aie::mul(v, r_vec); + aie::store_v(gp + off, acc.to_vector()); + } + } + } +} + +void sum_g(bfloat16 *g, bfloat16 *s) { + SET_ROUNDING(); + // s = sum(G, axis=-1, keepdims=True) + // Buffer shape: G: [lqp, lkp], s: [lqp, 1]; g is only read, s is written. + // G is column-major 8×8 tiled. VecLen=32 loads 4 rows at once.
+ constexpr int VecLen = 32; + constexpr int BlockSize = 64; + constexpr int ColsPerBlock = 8; + constexpr int RowsPerBlock = 8; + constexpr int col_blocks = lkp / ColsPerBlock; + constexpr int row_blocks = lqp / RowsPerBlock; + constexpr int block_stride = lqp * ColsPerBlock; + + bfloat16 *__restrict ps = s; + for (int rb = 0; rb < row_blocks; rb++) { + for (int half = 0; half < 2; half++) { + // Accumulate sum across column blocks for 4 rows + aie::accum sum_acc = aie::zeros(); + int base = rb * BlockSize + half * VecLen; + for (int cb = 0; cb < col_blocks; cb++) + chess_prepare_for_pipelining chess_loop_range(8, ) { + aie::vector v = + aie::load_v(g + base + cb * block_stride); + sum_acc = aie::add(sum_acc, v); + } + // Reduce each 8-element row slice to get per-row sum. + // Use f32 for reduce_add to preserve precision (IRON does this). + aie::vector sum_v = sum_acc.to_vector(); + aie::vector r0 = sum_v.extract<8>(0); + aie::vector r1 = sum_v.extract<8>(1); + aie::vector r2 = sum_v.extract<8>(2); + aie::vector r3 = sum_v.extract<8>(3); + ps[half * 4 + 0] = (bfloat16)aie::reduce_add(r0); + ps[half * 4 + 1] = (bfloat16)aie::reduce_add(r1); + ps[half * 4 + 2] = (bfloat16)aie::reduce_add(r2); + ps[half * 4 + 3] = (bfloat16)aie::reduce_add(r3); + } + // Advance to the s entries of the next 8-row block. + ps += RowsPerBlock; + } +} + +void accum_sp_r_s(bfloat16 *sp, bfloat16 *r, bfloat16 *s) { + SET_ROUNDING(); + // s += sp * r (elementwise; s is updated in place, sp and r are only read) + // Buffer shape: + // sp: [lqp, 1] = [32, 1] + // r: [lqp, 1] = [32, 1] + // s: [lqp, 1] = [32, 1] + // NOTE(review): the "= [32, 1]" figures presume lqp == 32; the loop below + // steps by VecLen=32 up to lqp, so it needs lqp to be a multiple of 32 — + // confirm against the lqp this file is built with. + constexpr int VecLen = 32; + constexpr int num_elems = lqp; + bfloat16 *__restrict pr = r; + bfloat16 *__restrict ps = s; + bfloat16 *__restrict psp = sp; + for (int i = 0; i < num_elems; i += VecLen) { + aie::vector rTemp = aie::load_v(pr); + aie::vector spTemp = aie::load_v(psp); + aie::accum accTemp = aie::mul(rTemp, spTemp); + accTemp = aie::add(accTemp, aie::load_v(ps)); + aie::vector sTemp = to_v32bfloat16(accTemp); + aie::store_v(ps, sTemp); + pr += VecLen; + ps += VecLen; + psp += VecLen;
+ } +} + +void vector_copy_32elems(const int offset, const bfloat16 *__restrict inputs, + bfloat16 *__restrict outputs) { + constexpr int VecLen = 32; + constexpr int num_elems = lqp; + const bfloat16 *__restrict pIn = inputs; + bfloat16 *__restrict pOut = outputs + offset; + for (unsigned j = 0; j < num_elems / VecLen; j++) { + aie::vector vec = aie::load_v(pIn); + pIn += VecLen; + aie::store_v(pOut, vec); + pOut += VecLen; + } +} + +void div_gp_sp(bfloat16 *sp, bfloat16 *gp) { + SET_ROUNDING(); + // Gp = Gp / sp (per-row normalization) + // Implemented as Gp * (1 / sp): the reciprocal is taken once per group of + // 4 rows with aie::inv and then multiplied in, instead of dividing each + // element. gp is updated in place; sp is only read. + // Buffer shape: Gp: [lqp, dv], sp: [lqp, 1] + // Layout: column-major 8×8 block tiled (same as matmul output). + // block(col_blk, row_blk) at offset col_blk * (lqp * 8) + row_blk * 64, + // element within block at row_in * 8 + col_in. + // VecLen=32 reads 4 rows × 8 cols (half a block). + constexpr int VecLen = 32; + constexpr int BlockSize = 64; // 8×8 block + constexpr int ColsPerBlock = 8; + constexpr int RowsPerBlock = 8; + constexpr int col_blocks = dv / ColsPerBlock; + constexpr int row_blocks = lqp / RowsPerBlock; + constexpr int block_stride = + lqp * ColsPerBlock; // stride between column blocks + + for (int rb = 0; rb < row_blocks; rb++) { + for (int half = 0; half < 2; half++) { + // Build 32-wide 1/sp vector: 4 rows × 8 cols, each row's inv(sp) + // broadcast + int row_start = rb * RowsPerBlock + half * 4; + aie::vector sp0 = aie::broadcast(sp[row_start]); + aie::vector sp1 = + aie::broadcast(sp[row_start + 1]); + aie::vector sp2 = + aie::broadcast(sp[row_start + 2]); + aie::vector sp3 = + aie::broadcast(sp[row_start + 3]); + aie::vector sp_vec; + sp_vec.insert(0, sp0); + sp_vec.insert(1, sp1); + sp_vec.insert(2, sp2); + sp_vec.insert(3, sp3); + aie::vector sp_inv = aie::inv(sp_vec); + + int base = rb * BlockSize + half * VecLen; + for (int cb = 0; cb < col_blocks; cb++) + chess_prepare_for_pipelining chess_loop_range(8, ) { + int off = base + cb * block_stride; + aie::vector v = aie::load_v(gp + off); + aie::accum
acc = aie::mul(v, sp_inv); + aie::store_v(gp + off, acc.to_vector()); + } + } + } +} + +// Fused softmax: delegates to existing kernels (most use VecLen=32; +// exp_up_minus_u uses VecLen=16). +// Steps, in call order: +// max_g_bf16(g, r), maximum_up_u_bf16(up, r): declared above; presumably +// leave max(up, rowmax(g)) in r — TODO confirm. +// exp_g_minus_u(r, g): g = exp(g - r), in place. +// exp_up_minus_u(up, r, sp): sp = exp(up - r). +// vector_copy_32elems(0, r, up): up = r. Despite its name the helper copies +// lqp elements (lqp / 32 vectors of 32), writing at outputs + offset. +// vector_copy_32elems(0, sp, r): r = sp. +// sum_g(g, sp): sp = row sums of the updated g. +// On return: up=new_max, sp=sum(exp(G)), r=rescale_factor, G=exp(G-max). +void fused_softmax(bfloat16 *g, bfloat16 *up, bfloat16 *sp, bfloat16 *r) { + SET_ROUNDING(); + max_g_bf16(g, r); + maximum_up_u_bf16(up, r); + exp_g_minus_u(r, g); + exp_up_minus_u(up, r, sp); + vector_copy_32elems(0, r, up); + vector_copy_32elems(0, sp, r); + sum_g(g, sp); +} + +// g = gp + g, elementwise over lqp * dv elements, VecLen=32 per iteration. +// The sum is stored back to g; gp is only read. +void add_gp_g(bfloat16 *gp, bfloat16 *g) { + SET_ROUNDING(); + constexpr int VecLen = 32; + constexpr int num_elems = lqp * dv; + bfloat16 *__restrict gp_ptr = gp; + bfloat16 *__restrict g_ptr = g; + for (unsigned j = 0; j < num_elems / VecLen; j++) { + aie::vector gp_vec = aie::load_v(gp_ptr); + aie::vector g_vec = aie::load_v(g_ptr); + aie::accum acc(gp_vec); + acc = aie::add(acc, g_vec); + aie::store_v(g_ptr, acc.to_vector()); + gp_ptr += VecLen; + g_ptr += VecLen; + } +} + +// Apply causal mask to QK scores in-place. Sets elements where +// global_kv_col > global_q_row to -inf in the tiled G buffer. +// G is in column-major 8×8 tiled layout: block(col_blk, row_blk) at +// offset col_blk * (lqp * 8) + row_blk * 64, element within block at +// row_in_blk * 8 + col_in_blk. +// q_block_idx / kv_block_idx: only their ordering is used (>, <, ==). +// The mask value is -inf (0xff80), not the bf16 lowest (0xff7f) that the +// exp kernels clamp to. +// NOTE(review): the diagonal case compares tile-local row/col indices, which +// equals the global comparison only if the Q and KV tiles start at the same +// global index — confirm (looks like it presumes lqp == lkp). +void apply_causal_mask(bfloat16 *g, int32_t q_block_idx, int32_t kv_block_idx) { + SET_ROUNDING(); + uint16_t neg_inf_u16 = (uint16_t)0xff80; + bfloat16 neg_inf_val = *(bfloat16 *)&neg_inf_u16; + + // 1. Block above diagonal: all masked -> fill with -inf + if (kv_block_idx > q_block_idx) { + constexpr int VecLen = 32; + aie::vector neg_inf_vec = + aie::broadcast(neg_inf_val); + bfloat16 *p = g; + for (int i = 0; i < lqp * lkp; i += VecLen) { + aie::store_v(p, neg_inf_vec); + p += VecLen; + } + return; + } + + // 2. Block below diagonal: no masking needed + if (kv_block_idx < q_block_idx) { + return; + } + + // 3.
Diagonal block (kv_block_idx == q_block_idx): + // Read-modify-write ALL 8-element row slices for EVERY row. + // ("Block" in the comments below means one such 8-element row slice.) + // For unmasked blocks: read and write back unchanged. + // For masked blocks: write mask value. + // For partial blocks: read, select, write back. + // This ensures EVERY position goes through a vector load+store cycle. + constexpr int BlkDim = 8; + aie::vector mask_vec = + aie::broadcast(neg_inf_val); + + for (int row = 0; row < lqp; row++) { + // First masked column of this row: columns > row are masked. + int mask_start = row + 1; + int row_blk = row / BlkDim; + int row_in = row % BlkDim; + + for (int col_blk = 0; col_blk < lkp / BlkDim; col_blk++) { + int col_start = col_blk * BlkDim; + int off = col_blk * (lqp * BlkDim) + row_blk * (BlkDim * BlkDim) + + row_in * BlkDim; + + aie::vector orig = aie::load_v(g + off); + + if (mask_start >= lkp) { + // Last row or beyond: no masking, write back unchanged + aie::store_v(g + off, orig); + } else if (col_start >= mask_start) { + // Entire block masked + aie::store_v(g + off, mask_vec); + } else if (col_start + BlkDim > mask_start) { + // Partial block: bit c of sel_bits is set when column col_start + c + // is masked (col_start + c >= mask_start). + uint32_t sel_bits = 0; + for (int c = 0; c < BlkDim; c++) { + if (col_start + c >= mask_start) { + sel_bits |= (1u << c); + } + } + aie::mask sel(sel_bits); + aie::store_v(g + off, aie::select(orig, mask_vec, sel)); + } else { + // Unmasked block: write back unchanged + aie::store_v(g + off, orig); + } + } + } +} + +} // extern "C" diff --git a/programming_examples/flash_attention/packet_switched/run_npu1_makefile_peano.lit b/programming_examples/flash_attention/packet_switched/run_npu1_makefile_peano.lit new file mode 100644 index 000000000..4f87a29ad --- /dev/null +++ b/programming_examples/flash_attention/packet_switched/run_npu1_makefile_peano.lit @@ -0,0 +1,10 @@ +// (c) Copyright 2025 Advanced Micro Devices, Inc.
+// SPDX-License-Identifier: MIT +// +// REQUIRES: ryzen_ai_npu1, peano +// +// RUN: mkdir -p %t +// RUN: cd %t +// RUN: make -f %S/Makefile clean PEANO_INSTALL_DIR=%PEANO_INSTALL_DIR +// RUN: make -f %S/Makefile run-npu1 PEANO_INSTALL_DIR=%PEANO_INSTALL_DIR | FileCheck %s +// CHECK: PASS! diff --git a/programming_examples/flash_attention/packet_switched/run_npu2_makefile_peano.lit b/programming_examples/flash_attention/packet_switched/run_npu2_makefile_peano.lit new file mode 100644 index 000000000..c64626907 --- /dev/null +++ b/programming_examples/flash_attention/packet_switched/run_npu2_makefile_peano.lit @@ -0,0 +1,10 @@ +// (c) Copyright 2025 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT +// +// REQUIRES: ryzen_ai_npu2, peano +// +// RUN: mkdir -p %t +// RUN: cd %t +// RUN: make -f %S/Makefile clean PEANO_INSTALL_DIR=%PEANO_INSTALL_DIR +// RUN: make -f %S/Makefile run PEANO_INSTALL_DIR=%PEANO_INSTALL_DIR | FileCheck %s +// CHECK: PASS! diff --git a/programming_examples/flash_attention/packet_switched/zero.cc b/programming_examples/flash_attention/packet_switched/zero.cc new file mode 100644 index 000000000..0df2d1ca1 --- /dev/null +++ b/programming_examples/flash_attention/packet_switched/zero.cc @@ -0,0 +1,47 @@ +//===- zero.cc --------------------------------------------------*- C++ -*-===// +// +// SPDX-License-Identifier: MIT +// Copyright (C) 2025, Advanced Micro Devices, Inc. 
+// +//===----------------------------------------------------------------------===// + +#ifndef ZERO_CC +#define ZERO_CC + +#include +#include +#include +#include + +template +void zero_vectorized(T *__restrict c) { // zero-fills the M * N elements at c + const aie::vector zeros = aie::zeros(); + const T *__restrict c_end = c + M * N; + for (; c + r <= c_end; c += r) { // <=: the last full vector is stored here too + aie::store_v(c, zeros); + } + // Do a scalar write for any remainder not divisible by vector instruction + // size r (fewer than r elements are left at this point) + for (; c < c_end; c++) { + *c = 0; + } +} + +template +void neg_inf_vectorized(T *__restrict c) { // fills the M * N elements at c + // Despite the name, fills with bf16 lowest (0xff7f ≈ -3.39e38), not -inf + // (0xff80), to avoid NaN on AIE2P: (-inf) - (-inf) is NaN, whereas + // lowest - lowest = 0. Assumes T is a 16-bit type whose bits are bf16. + uint16_t lowest_u16 = (uint16_t)0xff7f; + T *T_lowest = (T *)&lowest_u16; + const aie::vector lowest_vec = aie::broadcast(*T_lowest); + const T *__restrict c_end = c + M * N; + for (; c + r <= c_end; c += r) { // <=: the last full vector is stored here too + aie::store_v(c, lowest_vec); + } + for (; c < c_end; c++) { // scalar tail: fewer than r elements remain + *c = *T_lowest; + } +} + +#endif