Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 45 additions & 22 deletions mlir/lib/Conversion/AIRLoweringPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ using namespace mlir;
namespace xilinx {
namespace air {

// Number of dimensions carried by an airrt DMA descriptor. Both
// airrt.dma_memcpy_nd and airrt.memcpy_nd take exactly four offsets, four
// lengths, and four strides (offset3..offset0, length3..length0,
// stride3..stride0), so longer lists have to be cut down to this many entries
// and shorter ones padded up to it.
static constexpr unsigned kAIRRtMaxNDims = 4u;

/// Return true if \p ifOp's condition is an arith.cmpi comparing a
/// scf.parallel induction variable — the segment-unroll index check pattern.
static bool isSegmentUnrollCondition(scf::IfOp ifOp) {
Expand Down Expand Up @@ -563,28 +568,47 @@ class AIRDmaMemcpyNdToAIRRtConversion
auto one = arith::ConstantOp::create(rewriter, loc, i64Ty,
IntegerAttr::get(i64Ty, 1));

SmallVector<Value, 4> offsets(4, zero);
SmallVector<Value, 4> lengths(4, one);
SmallVector<Value, 4> strides(4, zero);
SmallVector<Value, 4> offsets(kAIRRtMaxNDims, zero);
SmallVector<Value, 4> lengths(kAIRRtMaxNDims, one);
SmallVector<Value, 4> strides(kAIRRtMaxNDims, zero);

// The airrt format supports at most kAIRRtMaxNDims dimensions for offsets,
// sizes, and strides. When N exceeds this (e.g., from BD optimization or
// block-layout lowering), keep only the last kAIRRtMaxNDims elements. The
// dropped leading dimensions are expected to have zero offsets, as inserted
// by foldForLoopNestAsExtendedSizesAndStrides; dropping a non-zero leading
// offset would silently produce an incorrect transfer.
auto truncateToMaxDims = [](auto range) {
return range.size() > kAIRRtMaxNDims ? range.take_back(kAIRRtMaxNDims)
: range;
};

int idx = 4 - src.getRank();
for (auto o : isFromTile ? op.getDstOffsets() : op.getSrcOffsets())
auto allOffsets = isFromTile ? op.getDstOffsets() : op.getSrcOffsets();
if (allOffsets.size() > kAIRRtMaxNDims) {
for (auto o : allOffsets.drop_back(kAIRRtMaxNDims)) {
auto v = getConstantIntValue(o);
assert((!v || *v == 0) && "dropping non-zero leading DMA offset");
}
}
auto op_offsets = truncateToMaxDims(allOffsets);
int idx = kAIRRtMaxNDims - op_offsets.size();
for (auto o : op_offsets)
offsets[idx++] = arith::IndexCastOp::create(rewriter, op->getLoc(),
IntegerType::get(ctx, 64), o);
auto op_strides = isFromTile ? op.getDstStrides() : op.getSrcStrides();

auto op_strides =
truncateToMaxDims(isFromTile ? op.getDstStrides() : op.getSrcStrides());
if (op_strides.size()) {
// Take last min(4, N) strides, drop leading strides if N > 4.
// The innermost stride (last element) is now preserved.
auto strides_to_use = op_strides;
if (strides_to_use.size() > 4)
strides_to_use = strides_to_use.drop_front(strides_to_use.size() - 4);
idx = 4 - strides_to_use.size();
for (auto o : strides_to_use)
idx = kAIRRtMaxNDims - op_strides.size();
for (auto o : op_strides)
strides[idx++] = arith::IndexCastOp::create(
rewriter, op->getLoc(), IntegerType::get(ctx, 64), o);
}
idx = 4 - src.getRank();
for (auto o : isFromTile ? op.getDstSizes() : op.getSrcSizes())

auto op_sizes =
truncateToMaxDims(isFromTile ? op.getDstSizes() : op.getSrcSizes());
idx = kAIRRtMaxNDims - op_sizes.size();
for (auto o : op_sizes)
lengths[idx++] = arith::IndexCastOp::create(rewriter, op->getLoc(),
IntegerType::get(ctx, 64), o);

Expand Down Expand Up @@ -699,23 +723,22 @@ AIRChannelInterfaceToAIRRtConversionImpl(OpBuilder builder,
return failure();
}

while (offsets.size() > 4) {
while (offsets.size() > kAIRRtMaxNDims) {
offsets.erase(offsets.begin());
}
while (offsets.size() < 4) {
while (offsets.size() < kAIRRtMaxNDims) {
offsets.insert(offsets.begin(), zero_idx);
}
while (wraps.size() > 4) {
while (wraps.size() > kAIRRtMaxNDims) {
wraps.erase(wraps.begin());
}
while (wraps.size() < 4) {
while (wraps.size() < kAIRRtMaxNDims) {
wraps.insert(wraps.begin(), one_idx);
}
// Truncate to last 4 elements if more than 4 strides.
while (strides.size() > 4) {
while (strides.size() > kAIRRtMaxNDims) {
strides.erase(strides.begin());
}
while (strides.size() < 4) {
while (strides.size() < kAIRRtMaxNDims) {
strides.insert(strides.begin(), zero_idx);
}

Expand Down
37 changes: 37 additions & 0 deletions mlir/test/Conversion/AIRLowering/air_dma_nd_6d_to_airrt.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//===- air_dma_nd_6d_to_airrt.mlir -----------------------------*- MLIR -*-===//
//
// Copyright (C) 2022, Xilinx Inc. All rights reserved.
// Copyright (C) 2022-2026, Advanced Micro Devices, Inc. All rights reserved.
// SPDX-License-Identifier: MIT
//
//===----------------------------------------------------------------------===//

// Verify that air-to-std correctly truncates >4D offset/size/stride lists
// to 4D for airrt.dma_memcpy_nd. The BD optimization pass and block-layout
// lowering can produce 6D patterns that must be truncated to fit the 4D
// hardware BD format.

// RUN: air-opt %s -air-to-std -cse | FileCheck %s

// CHECK-LABEL: func.func @dma_6d
// The 6D DMA is truncated to 4D: leading 2 (trivial) dimensions are dropped.
// The type signature confirms exactly 4 elements in each bracket group.
// CHECK: airrt.dma_memcpy_nd({{.*}}) : (i32, i64, i64, memref<64x64xi32>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64, i64])
module {
func.func @dma_6d(%arg0: memref<64x64xi32>) {
%c2 = arith.constant 2 : index
air.herd tile (%tx, %ty) in (%sx=%c2, %sy=%c2) args(%ext=%arg0) : memref<64x64xi32> attributes {sym_name = "herd_0"} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c32 = arith.constant 32 : index
%c64 = arith.constant 64 : index
%c2048 = arith.constant 2048 : index
%buf = memref.alloc() : memref<32x64xi32, 2>
// The source side carries six offsets, sizes, and strides while the
// destination lists are left empty. The two leading source dimensions are
// trivial (offset=0, size=1, stride=0); the remaining four entries are
// offsets [0, 0, 0, 0], sizes [4, 32, 64, 1], and strides [2048, 64, 1, 0].
air.dma_memcpy_nd (%buf[] [] [], %ext[%c0, %c0, %c0, %c0, %c0, %c0] [%c1, %c1, %c4, %c32, %c64, %c1] [%c0, %c0, %c2048, %c64, %c1, %c0]) {id = 1 : i32} : (memref<32x64xi32, 2>, memref<64x64xi32>)
memref.dealloc %buf : memref<32x64xi32, 2>
}
return
}
}
Loading