Skip to content

Commit df4f24d

Browse files
erwei-xilinxclaude
andcommitted
Add V cache write-back with interleaved KV cache layout
Extend the KV cache prefill design to write both K and V caches to DDR during flash attention computation. Uses a single CacheWB channel with an interleaved KV cache layout [K_c0, V_c0, K_c1, V_c1, ...] where both K and V data are staged through kwb_buf before DMA transfer. Key design choices: - Single CacheWB channel avoids shim S2MM channel exhaustion (no packet switching needed) - Shared kwb_buf staging buffer prevents DMA race between CacheWB read and V2L1 write on the v buffer - scf.for loop in launch body enables compiler BD folding, preventing BD exhaustion at large sequence lengths (tested up to 12h x 4096) Compiler changes (AIRToAIEPass.cpp): - Fix packet BD attribute lookup for L1-to-L3 dma_packet channels (getExistingPacketFlowOpFromDevice searches both flow maps) - Place outbound MM2S lock acquire before channel put and release after channel put, enabling interleaved lock pattern for multiple puts sharing the same staging buffer Performance: 12 heads x 4096 seq_len achieves 2460 peak GFLOPS with zero overhead vs K-only writeback. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent bc95ac7 commit df4f24d

4 files changed

Lines changed: 344 additions & 149 deletions

File tree

mlir/lib/Conversion/AIRToAIEPass.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4685,12 +4685,7 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase<AIRToAIEPass> {
46854685
builder.setInsertionPoint(alloc.getDefiningOp());
46864686
else if (!tileInbound.value() &&
46874687
isa<AIE::BufferOp>(alloc.getDefiningOp())) {
4688-
auto br = dyn_cast_if_present<cf::BranchOp>(
4689-
memcpyOpIf->getBlock()->getTerminator());
4690-
if (br)
4691-
builder.setInsertionPointToStart(br.getDest());
4692-
else
4693-
builder.setInsertionPointToStart(memcpyOpIf->getBlock());
4688+
builder.setInsertionPoint(memcpyOpIf);
46944689
} else
46954690
builder.setInsertionPoint(memcpyOpIf);
46964691

@@ -4701,8 +4696,14 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase<AIRToAIEPass> {
47014696
lockAqValue);
47024697

47034698
// Try to find the end of lifetime for the data copied by memcpyOpIf, and
4704-
// put the unlock.
4705-
if (auto nextWriter = findNextDmaWriteOp(memcpyOpIf, alloc)) {
4699+
// put the unlock. For outbound puts from AIE::BufferOp, release
4700+
// immediately after the put to enable interleaved operation when multiple
4701+
// puts share the same staging buffer.
4702+
if (!tileInbound.value() && isa<AIE::BufferOp>(alloc.getDefiningOp())) {
4703+
builder.setInsertionPointAfter(memcpyOpIf);
4704+
AIE::UseLockOp::create(builder, memcpyOpIf->getLoc(), relLockOp,
4705+
AIE::LockAction::Release, lockRelValue);
4706+
} else if (auto nextWriter = findNextDmaWriteOp(memcpyOpIf, alloc)) {
47064707
// Lifetime ends if dma writes into the same buffer.
47074708
builder.setInsertionPoint(nextWriter);
47084709
AIE::UseLockOp::create(builder, nextWriter->getLoc(), relLockOp,

mlir/test/Conversion/AIRToAIE/air_shimcpy_to_npu.mlir

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1764,3 +1764,48 @@ module {
17641764
return
17651765
}
17661766
}
1767+
1768+
// -----
1769+
1770+
// L1-to-L3 packet flow: ChannelPut in herd (L1 source) with ChannelGet in
1771+
// launch (L3 destination). Verifies that the packet attribute is generated on
1772+
// the compute tile MM2S BD despite ChannelPut.getDstMemref() returning null.
1773+
1774+
// CHECK: aie.mem(%[[TILE:.*]]) {
1775+
// CHECK: aie.dma_start(MM2S, 0
1776+
// CHECK: aie.dma_bd(%{{.*}} : memref<64xbf16, 2>{{.*}}) {{{.*}}packet = #aie.packet_info<pkt_type = 0, pkt_id = 0>
1777+
// CHECK: aie.packet_flow(0) {
1778+
// CHECK: aie.packet_source<%[[TILE]], DMA : 0>
1779+
// CHECK: aie.packet_dest<%{{.*}}, DMA :
1780+
1781+
// RACECONDFIX: @func21
1782+
1783+
module {
1784+
air.channel @L1ToL3Pkt [1, 1] {channel_type = "dma_packet"}
1785+
func.func @func21(%arg0: memref<64xbf16>) {
1786+
%c1 = arith.constant 1 : index
1787+
%c0 = arith.constant 0 : index
1788+
%0 = air.launch async () in () args(%out=%arg0) : memref<64xbf16> attributes {id = 1 : i32} {
1789+
%lc0 = arith.constant 0 : index
1790+
// L3 destination: ChannelGet into L3 memref
1791+
%1 = air.channel.get async @L1ToL3Pkt[%lc0, %lc0] (%out[] [] []) {id = 1 : i32} : (memref<64xbf16>)
1792+
%lc1 = arith.constant 1 : index
1793+
%2 = air.segment @seg async attributes {id = 2 : i32, x_loc = 0 : i64, y_loc = 2 : i64} {
1794+
%sc1 = arith.constant 1 : index
1795+
%3 = air.herd @herd async tile (%tx, %ty) in (%htx=%sc1, %hty=%sc1) attributes {id = 3 : i32} {
1796+
%hc0 = arith.constant 0 : index
1797+
%async_token, %buf = air.execute -> (memref<64xbf16, 2>) {
1798+
%alloc = memref.alloc() : memref<64xbf16, 2>
1799+
air.execute_terminator %alloc : memref<64xbf16, 2>
1800+
}
1801+
// L1 source: ChannelPut from L1 buffer
1802+
%put = air.channel.put async [%async_token] @L1ToL3Pkt[%hc0, %hc0] (%buf[] [] []) {id = 2 : i32} : (memref<64xbf16, 2>)
1803+
%4 = air.execute [%put] {
1804+
memref.dealloc %buf : memref<64xbf16, 2>
1805+
}
1806+
}
1807+
}
1808+
}
1809+
return
1810+
}
1811+
}

0 commit comments

Comments
 (0)