diff --git a/mlir/lib/Conversion/AIRLoweringPass.cpp b/mlir/lib/Conversion/AIRLoweringPass.cpp
index 90ccec230..f2c8eff9f 100644
--- a/mlir/lib/Conversion/AIRLoweringPass.cpp
+++ b/mlir/lib/Conversion/AIRLoweringPass.cpp
@@ -49,6 +49,11 @@ using namespace mlir;
 namespace xilinx {
 namespace air {
 
+// Maximum number of dimensions for offsets/sizes/strides in the airrt DMA
+// format. Matches the 4-element layout of airrt.dma_memcpy_nd and
+// airrt.memcpy_nd (offset3..offset0, length3..length0, stride3..stride0).
+static constexpr unsigned kAIRRtMaxNDims = 4;
+
 /// Return true if \p ifOp's condition is an arith.cmpi comparing a
 /// scf.parallel induction variable — the segment-unroll index check pattern.
 static bool isSegmentUnrollCondition(scf::IfOp ifOp) {
@@ -563,28 +568,47 @@ class AIRDmaMemcpyNdToAIRRtConversion
     auto one = arith::ConstantOp::create(rewriter, loc, i64Ty,
                                          IntegerAttr::get(i64Ty, 1));
 
-    SmallVector offsets(4, zero);
-    SmallVector lengths(4, one);
-    SmallVector strides(4, zero);
+    SmallVector offsets(kAIRRtMaxNDims, zero);
+    SmallVector lengths(kAIRRtMaxNDims, one);
+    SmallVector strides(kAIRRtMaxNDims, zero);
+
+    // The airrt format supports at most kAIRRtMaxNDims dimensions for offsets,
+    // sizes, and strides. When N exceeds this (e.g., from BD optimization or
+    // block-layout lowering), keep only the last kAIRRtMaxNDims elements. The
+    // leading dimensions are expected to be zero-offset (as inserted by
+    // foldForLoopNestAsExtendedSizesAndStrides); a non-zero leading offset
+    // would silently produce an incorrect transfer, so it is asserted below.
+    auto truncateToMaxDims = [](auto range) {
+      return range.size() > kAIRRtMaxNDims ? range.take_back(kAIRRtMaxNDims)
+                                           : range;
+    };
 
-    int idx = 4 - src.getRank();
-    for (auto o : isFromTile ? op.getDstOffsets() : op.getSrcOffsets())
+    auto allOffsets = isFromTile ? op.getDstOffsets() : op.getSrcOffsets();
+    if (allOffsets.size() > kAIRRtMaxNDims) {
+      for (auto o : allOffsets.drop_back(kAIRRtMaxNDims)) {
+        [[maybe_unused]] auto v = getConstantIntValue(o);
+        assert((!v || *v == 0) && "dropping non-zero leading DMA offset");
+      }
+    }
+    auto op_offsets = truncateToMaxDims(allOffsets);
+    int idx = kAIRRtMaxNDims - op_offsets.size();
+    for (auto o : op_offsets)
       offsets[idx++] = arith::IndexCastOp::create(rewriter, op->getLoc(),
                                                   IntegerType::get(ctx, 64), o);
-    auto op_strides = isFromTile ? op.getDstStrides() : op.getSrcStrides();
+
+    auto op_strides =
+        truncateToMaxDims(isFromTile ? op.getDstStrides() : op.getSrcStrides());
     if (op_strides.size()) {
-      // Take last min(4, N) strides, drop leading strides if N > 4.
-      // The innermost stride (last element) is now preserved.
-      auto strides_to_use = op_strides;
-      if (strides_to_use.size() > 4)
-        strides_to_use = strides_to_use.drop_front(strides_to_use.size() - 4);
-      idx = 4 - strides_to_use.size();
-      for (auto o : strides_to_use)
+      idx = kAIRRtMaxNDims - op_strides.size();
+      for (auto o : op_strides)
         strides[idx++] = arith::IndexCastOp::create(
             rewriter, op->getLoc(), IntegerType::get(ctx, 64), o);
     }
-    idx = 4 - src.getRank();
-    for (auto o : isFromTile ? op.getDstSizes() : op.getSrcSizes())
+
+    auto op_sizes =
+        truncateToMaxDims(isFromTile ? op.getDstSizes() : op.getSrcSizes());
+    idx = kAIRRtMaxNDims - op_sizes.size();
+    for (auto o : op_sizes)
       lengths[idx++] = arith::IndexCastOp::create(rewriter, op->getLoc(),
                                                   IntegerType::get(ctx, 64), o);
 
@@ -699,23 +723,22 @@ AIRChannelInterfaceToAIRRtConversionImpl(OpBuilder builder,
     return failure();
   }
 
-  while (offsets.size() > 4) {
+  while (offsets.size() > kAIRRtMaxNDims) {
     offsets.erase(offsets.begin());
   }
-  while (offsets.size() < 4) {
+  while (offsets.size() < kAIRRtMaxNDims) {
     offsets.insert(offsets.begin(), zero_idx);
   }
-  while (wraps.size() > 4) {
+  while (wraps.size() > kAIRRtMaxNDims) {
     wraps.erase(wraps.begin());
   }
-  while (wraps.size() < 4) {
+  while (wraps.size() < kAIRRtMaxNDims) {
     wraps.insert(wraps.begin(), one_idx);
   }
-  // Truncate to last 4 elements if more than 4 strides.
-  while (strides.size() > 4) {
+  while (strides.size() > kAIRRtMaxNDims) {
     strides.erase(strides.begin());
   }
-  while (strides.size() < 4) {
+  while (strides.size() < kAIRRtMaxNDims) {
     strides.insert(strides.begin(), zero_idx);
   }
 
diff --git a/mlir/test/Conversion/AIRLowering/air_dma_nd_6d_to_airrt.mlir b/mlir/test/Conversion/AIRLowering/air_dma_nd_6d_to_airrt.mlir
new file mode 100644
index 000000000..d538569a3
--- /dev/null
+++ b/mlir/test/Conversion/AIRLowering/air_dma_nd_6d_to_airrt.mlir
@@ -0,0 +1,37 @@
+//===- air_dma_nd_6d_to_airrt.mlir -----------------------------*- MLIR -*-===//
+//
+// Copyright (C) 2022, Xilinx Inc. All rights reserved.
+// Copyright (C) 2022-2026, Advanced Micro Devices, Inc. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+//===----------------------------------------------------------------------===//
+
+// Verify that air-to-std correctly truncates >4D offset/size/stride lists
+// to 4D for airrt.dma_memcpy_nd. The BD optimization pass and block-layout
+// lowering can produce 6D patterns that must be truncated to fit the 4D
+// hardware BD format.
+
+// RUN: air-opt %s -air-to-std -cse | FileCheck %s
+
+// CHECK-LABEL: func.func @dma_6d
+// The 6D DMA is truncated to 4D: leading 2 (trivial) dimensions are dropped.
+// The type signature confirms exactly 4 elements in each bracket group.
+// CHECK: airrt.dma_memcpy_nd({{.*}}) : (i32, i64, i64, memref<64x64xi32>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64, i64])
+module {
+  func.func @dma_6d(%arg0: memref<64x64xi32>) {
+    %c2 = arith.constant 2 : index
+    air.herd tile (%tx, %ty) in (%sx=%c2, %sy=%c2) args(%ext=%arg0) : memref<64x64xi32> attributes {sym_name = "herd_0"} {
+      %c0 = arith.constant 0 : index
+      %c1 = arith.constant 1 : index
+      %c4 = arith.constant 4 : index
+      %c32 = arith.constant 32 : index
+      %c64 = arith.constant 64 : index
+      %c2048 = arith.constant 2048 : index
+      %buf = memref.alloc() : memref<32x64xi32, 2>
+      // 6D offsets/sizes/strides: leading 2 dims are trivial (offset=0, size=1).
+      air.dma_memcpy_nd (%buf[] [] [], %ext[%c0, %c0, %c0, %c0, %c0, %c0] [%c1, %c1, %c4, %c32, %c64, %c1] [%c0, %c0, %c2048, %c64, %c1, %c0]) {id = 1 : i32} : (memref<32x64xi32, 2>, memref<64x64xi32>)
+      memref.dealloc %buf : memref<32x64xi32, 2>
+    }
+    return
+  }
+}