Commit dcbe37e

erwei-xilinx and claude authored
[airrt-to-npu] Honor inner-dim alignment when tiling oversized wraps (#1586)
* [airrt-to-npu] Honor inner-dim alignment when tiling oversized wraps

`tileIllegalWrapDim` currently calls `findLargestFactor(wrap, 1023)` to split a wrap that exceeds the shim 10-bit limit. For a contiguous bf16 transfer of length 131136 the factor it picks is 683 (an odd prime), producing an inner segment of 683 elem * 2 B = 1366 B that fails the `aie.dma_bd` 4-byte-alignment verifier.

Add `findLargestAlignedFactor`, used only when tiling the contiguous innermost dim (`stride == 1`), with `alignment = addressGenGranularity / elemBits` (= 2 for bf16, 1 for f32, 4 for i8). For the bf16 length-131136 case it now picks 192 instead, yielding `[<size=683, stride=192>, <size=192, stride=1>]`, with both dims 4-B aligned.

This bug fires for any bf16 / sub-32-bit transfer whose length factors cleanly only into an odd prime above ~511. Surfaced by an LLaMA-3.2-1B decode attention design where the K/V cache load is `(pos+1) * head_dim` and `pos+1 = 2049 = 3 * 683`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Emit diagnostic when no aligned factor exists

Drop the silent fallback in `findLargestAlignedFactor` and have it return 0 when no factor of `num` in `[alignment, max]` is a multiple of `alignment`. Plumb the failure through `tileIllegalWrapDim` and `enforceAIE2WrapLimit` so the pass emits an op-level error and `signalPassFailure()` instead of producing IR that the downstream `aie.dma_bd` verifier rejects with a generic alignment message.

The diagnostic names the offending dim, the size, the legal range, and the byte ratio (shim address granularity / element size), so the user knows whether to reshape the transfer or pad the inner dimension.

Add a third sub-test exercising bf16 length 2049 (= 3 * 683): the only factors <= 1023 are 1, 3, and 683, all odd, so no aligned factor exists and the diagnostic fires.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Consolidate alignment-aware tiling into air::Util; cover canonicalizeWrapAndStrideList

- Move findLargestAlignedFactor into air:: (Util.h/Util.cpp); delete the private duplicate of findLargestFactor in AIRRtToNpuPass.cpp that had been accumulating helpers next to the canonical one in air::.
- Add air::getDmaInnerElementAlignment(memrefTy, op) so both tileIllegalWrapDim and canonicalizeWrapAndStrideList derive alignment the same way (DataLayout for the element width, fixed 32-bit shim address granularity matching AIE2/AIE2P AIETargetModel::getAddressGenGranularity).
- Apply the alignment fix to air::canonicalizeWrapAndStrideList via a new innerAlignment parameter (default 1, no behavior change for callers that don't opt in). Update the three shim-bound call sites in AIRToAIEPass.cpp and AIRDependencyScheduleOpt.cpp to pass the derived alignment so the bug is caught earlier in the pipeline. When no aligned factor exists at this layer, leave the dim oversized so AIRRtToNpuPass emits the diagnostic with full op context; this avoids a bare LogicalResult failure that callers ignore.
- Tighten the inline comment in tileIllegalWrapDim now that the bug story lives in the commit message and the helper docstring.
- Add an i8 sub-test (length 1028 = 4*257; alignment=4 forces inner wrap to drop from 514 to 4) and an NPU2 sub-test (mirrors the bf16 case, guarding against a future device divergence in addressGenGranularity).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Query AIE target model for shim address-gen granularity when reachable

getDmaInnerElementAlignment now consults the parent AIE::DeviceOp's target model via getAddressGenGranularity() instead of hardcoding 32. The hardcoded 32 stays as a fallback for when the op has no DeviceOp ancestor (early-pipeline contexts) or when AIR is built with AIR_ENABLE_AIE=OFF.

Pulls AIE into AIRUtil's link libs conditionally on AIR_ENABLE_AIE, matching the pattern used by AIRConversionPasses and AIRTransformPasses. The include is similarly guarded so the AIE-disabled build still works.

For both AIE2 and AIE2P (the only current devices) this reads the same 32 the fallback would have produced, so no test output changes, but a future device with a different addressGenGranularity will now Just Work without anyone having to remember to update a constant.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent fd62d7c commit dcbe37e
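The byte arithmetic behind the failure described in the commit message can be checked in isolation. This is a minimal sketch, not code from the commit: `innerSegmentBytes` and `isShimAligned` are hypothetical helper names, and the 4-byte granularity default is an assumption taken from the commit message's description of the `aie.dma_bd` verifier.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helpers for illustration only; not part of the AIR code base.

// Byte size of a contiguous inner segment of `wrap` elements.
constexpr std::int64_t innerSegmentBytes(std::int64_t wrap,
                                         std::int64_t elemBytes) {
  return wrap * elemBytes;
}

// True when the segment is a whole number of address-granularity units.
// The 4-byte default mirrors the alignment quoted in the commit message.
constexpr bool isShimAligned(std::int64_t wrap, std::int64_t elemBytes,
                             std::int64_t granularityBytes = 4) {
  return innerSegmentBytes(wrap, elemBytes) % granularityBytes == 0;
}

// bf16 is 2 bytes wide. The old split used an inner wrap of 683, the new
// split uses 192; both cover the same 131136-element transfer.
static_assert(683 * 192 == 131136, "both splits cover the same transfer");
static_assert(innerSegmentBytes(683, 2) == 1366, "old inner segment size");
static_assert(!isShimAligned(683, 2), "1366 B is not a multiple of 4 B");
static_assert(isShimAligned(192, 2), "384 B is a multiple of 4 B");
```

The same inner wrap of 683 would be acceptable for a 4-byte element type, which is why the constraint only matters for element types narrower than the granularity.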

7 files changed

Lines changed: 317 additions & 54 deletions


mlir/include/air/Util/Util.h

Lines changed: 26 additions & 3 deletions
@@ -192,10 +192,33 @@ LogicalResult foldForLoopNestAsExtendedSizesAndStrides(
 // Find the largest factor of 'num' which is not larger than 'max'.
 int findLargestFactor(int num, int max);

+// Largest factor of 'num' that is <= 'max' and a multiple of 'alignment'.
+// Returns 0 when no aligned factor exists, so the caller can emit a
+// diagnostic instead of silently producing misaligned IR. With alignment<=1
+// behaves as findLargestFactor.
+int findLargestAlignedFactor(int num, int max, int alignment);
+
+// Element-count alignment required so that an inner DMA wrap stays a
+// multiple of the AIE shim address granularity. Queries the parent
+// AIE::DeviceOp's target model when reachable (preferred); otherwise falls
+// back to 32 bits (the AIE2 / AIE2P value). Returns 1 when each element
+// already meets or exceeds the granularity (e.g. f32, i32); 2 for bf16/i16;
+// 4 for i8/ui8. 'op' is used both to find the DeviceOp ancestor and to
+// resolve the DataLayout via DataLayout::closest().
+int getDmaInnerElementAlignment(mlir::BaseMemRefType memrefTy,
+                                mlir::Operation *op);
+
 // Canonicalize wrap and stride lists, by removing redundant dimensions.
-LogicalResult canonicalizeWrapAndStrideList(
-    OpBuilder &builder, SmallVector<Value> &offsets, SmallVector<Value> &sizes,
-    SmallVector<Value> &strides, int memref_volume, int maxSize = -1);
+// 'innerAlignment' constrains the contiguous innermost dim (stride==1) when
+// it must be split: the new inner wrap is forced to be a multiple of
+// 'innerAlignment' elements (e.g. 2 for bf16 on a 4-byte shim BD). Pass 1
+// (the default) when no extra constraint applies.
+LogicalResult canonicalizeWrapAndStrideList(OpBuilder &builder,
+                                            SmallVector<Value> &offsets,
+                                            SmallVector<Value> &sizes,
+                                            SmallVector<Value> &strides,
+                                            int memref_volume, int maxSize = -1,
+                                            int innerAlignment = 1);

 // If wrap-and-stride lists are empty, populate them with default data access
 // layout (contiguous, row-major).
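The element-count alignment documented for `getDmaInnerElementAlignment` reduces to an integer division of the granularity by the element width. The sketch below restates that rule on raw bit widths so it can be checked without MLIR; `innerElementAlignment` is a hypothetical stand-in, and the 32-bit default is the fallback value named in the header comment.

```cpp
#include <cassert>

// Hypothetical stand-in for air::getDmaInnerElementAlignment that takes raw
// bit widths instead of an MLIR memref type and operation.
constexpr int innerElementAlignment(unsigned elemBits,
                                    unsigned addrGenBits = 32) {
  if (elemBits == 0)
    return 1; // Unknown element size: impose no constraint.
  if (elemBits >= addrGenBits)
    return 1; // Each element already fills at least one granularity unit.
  return static_cast<int>(addrGenBits / elemBits);
}

static_assert(innerElementAlignment(32) == 1, "f32 / i32");
static_assert(innerElementAlignment(16) == 2, "bf16 / i16");
static_assert(innerElementAlignment(8) == 4, "i8 / ui8");
```

An inner wrap that is a multiple of this value always spans a whole number of granularity units, which is the property the tiling code relies on.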

mlir/lib/Conversion/AIRRtToNpuPass.cpp

Lines changed: 30 additions & 40 deletions
@@ -1179,40 +1179,7 @@ bool violatesAIE2WrapLimit(airrt::DmaMemcpyNdOp dma) {
   return false;
 }

-// Find the largest factor of 'num' which is not larger than 'max'. Ref:
-// https://github.com/nod-ai/iree-amd-aie/blob/main/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEUtils.cpp#L334
-int findLargestFactor(int num, int max) {
-  // No factors less than or equal to 0 exist
-  if (max <= 0)
-    return 0;
-
-  // Do O(1) instead of O(sqrt(num)) computation for this common case.
-  if (num <= max) {
-    return num;
-  }
-
-  int largestLowFactor = 1;
-  for (int lowFactor = 2; lowFactor <= max; ++lowFactor) {
-    const int highFactor = num / lowFactor;
-
-    // This early exit is what makes this O(sqrt(num)) instead of O(num).
-    if (highFactor < lowFactor)
-      return largestLowFactor;
-
-    const bool areActuallyFactors = num % lowFactor == 0;
-    if (areActuallyFactors) {
-      // We're certain that here lowFactor <= highFactor, and highFactor is
-      // descending in this loop. So we can return immediately if highFactor
-      // is good.
-      if (highFactor <= max)
-        return highFactor;
-      largestLowFactor = lowFactor;
-    }
-  }
-  return largestLowFactor;
-}
-
-void tileIllegalWrapDim(airrt::DmaMemcpyNdOp memcpy_op) {
+LogicalResult tileIllegalWrapDim(airrt::DmaMemcpyNdOp memcpy_op) {
   auto loc = memcpy_op->getLoc();
   auto ctx = memcpy_op->getContext();
   auto oper_begin = memcpy_op.getOperands().begin();
@@ -1221,13 +1188,30 @@ void tileIllegalWrapDim(airrt::DmaMemcpyNdOp memcpy_op) {
   SmallVector<Value> strides(oper_begin + 12, oper_begin + 16);
   OpBuilder builder(memcpy_op);

+  auto memrefTy =
+      llvm::dyn_cast<BaseMemRefType>(memcpy_op.getMemref().getType());
+  int innerAlignment =
+      memrefTy ? air::getDmaInnerElementAlignment(memrefTy, memcpy_op) : 1;
+
   for (int i = wraps.size() - 1; i >= 0; i--) {
     auto const_wrap = *getConstantIntValue(wraps[i]);
     auto const_stride = *getConstantIntValue(strides[i]);
     if (const_wrap >= AIE2_WRAP_UPPER_BOUNDS[i]) {
-      // Found dimension with illegal wrap. Tiling. (Prefers smaller outer
-      // wrap values, as long as stride fits)
-      int a_wrap = findLargestFactor(const_wrap, AIE2_WRAP_UPPER_BOUNDS[i] - 1);
+      // Found dimension with illegal wrap. Prefers smaller outer wrap as
+      // long as stride fits. For stride==1, force the inner wrap to a
+      // multiple of innerAlignment elements so its byte size stays aligned
+      // to the shim address granularity (otherwise aie.dma_bd rejects it).
+      int alignment = (const_stride == 1) ? innerAlignment : 1;
+      int a_wrap = air::findLargestAlignedFactor(
+          const_wrap, AIE2_WRAP_UPPER_BOUNDS[i] - 1, alignment);
+      if (a_wrap == 0) {
+        return memcpy_op.emitOpError()
+               << "cannot tile dim " << i << " of size " << const_wrap
+               << " into shim-legal chunks: no factor in [" << alignment << ", "
+               << (AIE2_WRAP_UPPER_BOUNDS[i] - 1) << "] is a multiple of "
+               << alignment
+               << " elements. Reshape the transfer or pad the inner dimension.";
+      }
       int b_wrap = llvm::divideCeilSigned(const_wrap, a_wrap);
       int new_a_stride = const_stride * a_wrap;
       auto volume = air::getTensorVolume(
@@ -1357,9 +1341,10 @@ void tileIllegalWrapDim(airrt::DmaMemcpyNdOp memcpy_op) {
   }

   memcpy_op.erase();
+  return success();
 }

-void enforceAIE2WrapLimit(ModuleOp module) {
+LogicalResult enforceAIE2WrapLimit(ModuleOp module) {
   // Identify airrt.dma_memcpy_nd ops that violate the AIE2 wrap size
   // constraint.
   SmallVector<airrt::DmaMemcpyNdOp> target_airrt_dmas;
@@ -1374,7 +1359,9 @@ void enforceAIE2WrapLimit(ModuleOp module) {

   // Enforce the AIE2 wrap limit by tiling that dimension.
   for (auto memcpy_op : target_airrt_dmas)
-    tileIllegalWrapDim(memcpy_op);
+    if (failed(tileIllegalWrapDim(memcpy_op)))
+      return failure();
+  return success();
 }

 struct AIRRtToNpuPass : public impl::AIRRtToNpuBase<AIRRtToNpuPass> {
@@ -1455,7 +1442,10 @@ struct AIRRtToNpuPass : public impl::AIRRtToNpuBase<AIRRtToNpuPass> {
     generateNpuWaitFromAIRRtWaitAll(module);

     // Enforce AIE2 hardware constraints.
-    enforceAIE2WrapLimit(module);
+    if (failed(enforceAIE2WrapLimit(module))) {
+      signalPassFailure();
+      return;
+    }

     // Simplify arith ops (from airrt)
     RewritePatternSet canoPatterns_3(ctx);
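The split that `tileIllegalWrapDim` performs on one oversized dimension, and the case where it now reports an error instead, can be sketched without MLIR. This is a simplified illustration under stated assumptions: `Dim`, `largestAlignedFactor`, and `splitOversizedDim` are hypothetical names, the factor search is a plain downward scan, and offsets and the remaining dimensions handled by the real pass are left out. An empty `std::optional` stands in for the `emitOpError()` path.

```cpp
#include <cassert>
#include <optional>
#include <utility>

// Hypothetical types and helpers for illustration; not the pass's real API.
struct Dim {
  int size;
  int stride;
};

// Largest factor of `num` that is <= `max` and a multiple of `alignment`,
// or 0 when none exists. Plain downward scan over aligned candidates.
int largestAlignedFactor(int num, int max, int alignment) {
  if (alignment < 1)
    alignment = 1;
  for (int c = (max / alignment) * alignment; c >= alignment; c -= alignment)
    if (num % c == 0)
      return c;
  return 0;
}

// Splits one {wrap, stride} dim into {outer, inner}. Only a contiguous dim
// (stride == 1) is subject to the alignment constraint. Returns nullopt
// where the real pass would emit a diagnostic and fail.
std::optional<std::pair<Dim, Dim>>
splitOversizedDim(int wrap, int stride, int maxWrap, int innerAlignment) {
  const int alignment = (stride == 1) ? innerAlignment : 1;
  const int innerWrap = largestAlignedFactor(wrap, maxWrap, alignment);
  if (innerWrap == 0)
    return std::nullopt;
  const Dim outer{wrap / innerWrap, stride * innerWrap};
  const Dim inner{innerWrap, stride};
  return std::make_pair(outer, inner);
}
```

With bf16 (alignment 2) and a wrap limit of 1023, a contiguous dim of 131136 elements splits into an outer dim of size 683 with stride 192 and an inner dim of size 192 with stride 1, matching the pair quoted in the commit message. A contiguous dim of 2049 elements yields no result, which corresponds to the new diagnostic.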

mlir/lib/Conversion/AIRToAIEPass.cpp

Lines changed: 5 additions & 1 deletion
@@ -2686,9 +2686,13 @@ struct SpecializeChannelBundlePattern
       SmallVector<Value> offsets = ci.getOffsets();
       SmallVector<Value> wraps = ci.getSizes();
       SmallVector<Value> strides = ci.getStrides();
+      auto memrefTy = llvm::dyn_cast<BaseMemRefType>(ci.getMemref().getType());
+      int innerAlignment =
+          memrefTy ? air::getDmaInnerElementAlignment(memrefTy, ci) : 1;
       (void)air::canonicalizeWrapAndStrideList(
           builder, offsets, wraps, strides,
-          air::getTensorVolume(ci.getMemref().getType()), maxSize);
+          air::getTensorVolume(ci.getMemref().getType()), maxSize,
+          innerAlignment);
       air::ChannelInterface new_ci = nullptr;
       if (isa<air::ChannelPutOp>(ci))
         new_ci = air::ChannelPutOp::create(

mlir/lib/Transform/AIRDependencyScheduleOpt.cpp

Lines changed: 13 additions & 3 deletions
@@ -2196,9 +2196,14 @@ struct AIRSpecializeChannelWrapAndStrideInScfFor
     SmallVector<Value> strides = channel_op.getStrides();

     OpBuilder b(channel_op);
+    auto memrefTy =
+        llvm::dyn_cast<BaseMemRefType>(channel_op.getMemref().getType());
+    int innerAlignment =
+        memrefTy ? air::getDmaInnerElementAlignment(memrefTy, channel_op) : 1;
     (void)canonicalizeWrapAndStrideList(
         b, offsets, wraps, strides,
-        air::getTensorVolume(channel_op.getMemref().getType()), maxSize);
+        air::getTensorVolume(channel_op.getMemref().getType()), maxSize,
+        innerAlignment);

     // If empty offsets/sizes/strides, then populate the lists with default
     // values.
@@ -2225,7 +2230,8 @@ struct AIRSpecializeChannelWrapAndStrideInScfFor

     (void)canonicalizeWrapAndStrideList(
         rewriter, offsets, wraps, strides,
-        air::getTensorVolume(channel_op.getMemref().getType()), maxSize);
+        air::getTensorVolume(channel_op.getMemref().getType()), maxSize,
+        innerAlignment);

     // Whether repeat (i.e. stride = 0) is supported at highest dimension.
     if (enableRepeatAtHighestDim && !wraps.empty()) {
@@ -2605,9 +2611,13 @@ struct AIRCanonicalizeChannelPutGetOpWrapAndStrideList
     if (padBeforeCheck)
       return failure();
     // Canonicalize offsets/sizes/strides using a helper function.
+    auto memrefTy = llvm::dyn_cast<BaseMemRefType>(op.getMemref().getType());
+    int innerAlignment =
+        memrefTy ? air::getDmaInnerElementAlignment(memrefTy, op) : 1;
     if (failed(canonicalizeWrapAndStrideList(
             rewriter, offsets, sizes, strides,
-            air::getTensorVolume(op.getMemref().getType()), maxSize)))
+            air::getTensorVolume(op.getMemref().getType()), maxSize,
+            innerAlignment)))
       return failure();

     // When highest-dimension repeat is active, pad offsets/sizes/strides to

mlir/lib/Util/CMakeLists.txt

Lines changed: 14 additions & 2 deletions
@@ -2,6 +2,19 @@
 # Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: MIT

+set(AIRUTIL_LINK_LIBS
+  MLIRIR
+  MLIRTransforms
+)
+
+# AIE target model is queried for the shim address-gen granularity when
+# computing DMA inner-element alignment. Conditional on AIR_ENABLE_AIE; the
+# fallback path in Util.cpp is used when AIE is disabled or no DeviceOp is
+# reachable from the op.
+if(AIR_ENABLE_AIE)
+  list(APPEND AIRUTIL_LINK_LIBS AIE)
+endif()
+
 add_mlir_library(AIRUtil
   Util.cpp
   Outliner.cpp
@@ -15,6 +28,5 @@ add_mlir_library(AIRUtil
   AIRDialect

   LINK_LIBS PUBLIC
-  MLIRIR
-  MLIRTransforms
+  ${AIRUTIL_LINK_LIBS}
 )

mlir/lib/Util/Util.cpp

Lines changed: 65 additions & 5 deletions
@@ -9,6 +9,11 @@
 #include "air/Util/Util.h"
 #include "air/Dialect/AIR/AIRDialect.h"

+#if AIR_ENABLE_AIE
+#include "aie/Dialect/AIE/IR/AIEDialect.h"
+#include "aie/Dialect/AIE/IR/AIETargetModel.h"
+#endif
+
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -21,6 +26,7 @@
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Iterators.h"
 #include "mlir/IR/OperationSupport.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"

 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/Debug.h"
@@ -1126,10 +1132,52 @@ int air::findLargestFactor(int num, int max) {
   return largestLowFactor;
 }

+// Fallback shim address-gen granularity when we can't reach an AIE::DeviceOp
+// to query the target model. Matches AIETargetModel::getAddressGenGranularity
+// for AIE2 and AIE2P. The dynamic lookup below is preferred when available so
+// future devices with a different value just work.
+static constexpr unsigned kAIEShimAddrGenBitsFallback = 32;
+
+int air::getDmaInnerElementAlignment(BaseMemRefType memrefTy, Operation *op) {
+  if (!memrefTy || !op)
+    return 1;
+  DataLayout dl = DataLayout::closest(op);
+  unsigned elemBits = dl.getTypeSizeInBits(memrefTy.getElementType());
+  if (elemBits == 0)
+    return 1;
+  unsigned addrGenBits = kAIEShimAddrGenBitsFallback;
+#if AIR_ENABLE_AIE
+  if (auto dev = op->getParentOfType<AIE::DeviceOp>())
+    addrGenBits = dev.getTargetModel().getAddressGenGranularity();
+#endif
+  if (elemBits >= addrGenBits)
+    return 1;
+  return addrGenBits / elemBits;
+}
+
+// Largest factor of 'num' that is <= 'max' and a multiple of 'alignment'.
+// See header for rationale.
+int air::findLargestAlignedFactor(int num, int max, int alignment) {
+  if (alignment <= 1)
+    return findLargestFactor(num, max);
+  if (max < alignment)
+    return 0;
+  int alignedMax = (max / alignment) * alignment;
+  for (int candidate = alignedMax; candidate >= alignment;
+       candidate -= alignment) {
+    if (num % candidate == 0)
+      return candidate;
+  }
+  return 0;
+}
+
 // Canonicalize wrap and stride lists by removing redundant dimensions.
-LogicalResult air::canonicalizeWrapAndStrideList(
-    OpBuilder &builder, SmallVector<Value> &offsets, SmallVector<Value> &sizes,
-    SmallVector<Value> &strides, int memref_volume, int maxSize) {
+LogicalResult air::canonicalizeWrapAndStrideList(OpBuilder &builder,
+                                                 SmallVector<Value> &offsets,
+                                                 SmallVector<Value> &sizes,
+                                                 SmallVector<Value> &strides,
+                                                 int memref_volume, int maxSize,
+                                                 int innerAlignment) {
   // AIE2 hardware constraints. TODO: import these info from target model.
   const int AIE2_STRIDE_UPPER_BOUND = 1048576;
   bool listsHaveChanged = false;
@@ -1159,8 +1207,20 @@ LogicalResult air::canonicalizeWrapAndStrideList(
     if (const_wrap <= maxSize)
       continue;
     // Found dimension with illegal wrap. Tiling. (Prefers smaller outer wrap
-    // values, as long as stride fits)
-    int a_wrap = findLargestFactor(const_wrap, maxSize);
+    // values, as long as stride fits.) For the contiguous innermost dim
+    // (stride==1), require the new inner wrap to be a multiple of
+    // innerAlignment elements so the resulting d0 byte size stays aligned to
+    // the shim address granularity (e.g. 4 B for bf16 / i8). Falls back to
+    // findLargestFactor when innerAlignment <= 1 or stride != 1.
+    int a_wrap =
+        (const_stride == 1)
+            ? findLargestAlignedFactor(const_wrap, maxSize, innerAlignment)
+            : findLargestFactor(const_wrap, maxSize);
+    // No aligned factor exists. Leave the dim oversized and let the
+    // downstream shim lowering (tileIllegalWrapDim in AIRRtToNpuPass) emit
+    // the diagnostic with full op context.
+    if (a_wrap == 0)
+      continue;
     int b_wrap = llvm::divideCeilSigned(const_wrap, a_wrap);
     int new_a_stride = const_stride * a_wrap;
     if (memref_volume != 1)
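The two factor searches can be compared side by side on the lengths used in the commit message's sub-tests. The block below is a standalone sketch: the function bodies follow the `findLargestFactor` and `findLargestAlignedFactor` code shown in this diff, but they are placed outside namespace `air` and compiled without MLIR, and it assumes the `air::` copy of `findLargestFactor` behaves like the duplicate removed from AIRRtToNpuPass.cpp.

```cpp
#include <cassert>

// Standalone copy of the unaligned search, following the body removed from
// AIRRtToNpuPass.cpp in this commit.
int findLargestFactor(int num, int max) {
  if (max <= 0)
    return 0;
  if (num <= max)
    return num;
  int largestLowFactor = 1;
  for (int lowFactor = 2; lowFactor <= max; ++lowFactor) {
    const int highFactor = num / lowFactor;
    // Past sqrt(num): no new factor pairs remain.
    if (highFactor < lowFactor)
      return largestLowFactor;
    if (num % lowFactor == 0) {
      // highFactor descends as lowFactor grows, so the first one that fits
      // under 'max' is the largest factor.
      if (highFactor <= max)
        return highFactor;
      largestLowFactor = lowFactor;
    }
  }
  return largestLowFactor;
}

// Standalone copy of the aligned search added in Util.cpp: scan multiples of
// 'alignment' downward from 'max' and return the first that divides 'num'.
int findLargestAlignedFactor(int num, int max, int alignment) {
  if (alignment <= 1)
    return findLargestFactor(num, max);
  if (max < alignment)
    return 0;
  const int alignedMax = (max / alignment) * alignment;
  for (int candidate = alignedMax; candidate >= alignment;
       candidate -= alignment) {
    if (num % candidate == 0)
      return candidate;
  }
  return 0; // No aligned factor; the caller reports or defers the error.
}
```

For length 131136 with a limit of 1023, the unaligned search returns 683 and the search with alignment 2 returns 192. For the i8 case of length 1028 the results are 514 and, with alignment 4, 4. For length 2049 with alignment 2 the aligned search returns 0, which is the case the new diagnostic covers.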
