diff --git a/mlir/include/air/Transform/AIRLinalgBufferize.h b/mlir/include/air/Transform/AIRLinalgBufferize.h index 299fc29f8..67f9c9d5c 100644 --- a/mlir/include/air/Transform/AIRLinalgBufferize.h +++ b/mlir/include/air/Transform/AIRLinalgBufferize.h @@ -10,6 +10,8 @@ #include "air/Transform/PassDetail.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" namespace xilinx { @@ -18,6 +20,14 @@ namespace air { std::unique_ptr createAIRresolveTensorOpOperandConflictsWithNewTensors(); +/// Hoist statically-bound `memref.alloc` ops out of nested loops into the +/// function entry block. Wrapper around the file-scope template +/// `hoistStaticallyBoundAllocationsInFunc`. Used both by +/// `transform.air.hoist_static_alloc` (single-shot) and the +/// `air-hoist-static-alloc` pass. +void hoistStaticAllocsInFunc(::mlir::RewriterBase &rewriter, + ::mlir::FunctionOpInterface funcOp); + } // namespace air } // namespace xilinx diff --git a/mlir/include/air/Transform/AIRMatmulBufferizationPasses.h b/mlir/include/air/Transform/AIRMatmulBufferizationPasses.h new file mode 100644 index 000000000..75cf502dd --- /dev/null +++ b/mlir/include/air/Transform/AIRMatmulBufferizationPasses.h @@ -0,0 +1,56 @@ +//===- AIRMatmulBufferizationPasses.h ---------------------------*- C++ -*-===// +// +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT +// +//===----------------------------------------------------------------------===// +// +// Free-function bodies invoked by the air-matmul-codegen orchestrator: +// bufferization to L1/L2 allocations, post-bufferize cleanup, ping-pong +// loop fusion, and bf16-output truncf fusion. +// +//===----------------------------------------------------------------------===// + +#ifndef AIR_MATMUL_BUFFERIZATION_PASSES_H +#define AIR_MATMUL_BUFFERIZATION_PASSES_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/StringRef.h" + +namespace xilinx { +namespace air { + +// Free-function bodies for the now-internal pass impls. Called from +// option-driven steps in parametric passes (pack-and-transpose, +// prologue-epilogue, tile-for-vectorize, bufferize-output-l2). +mlir::LogicalResult runFusePingpongLoopsImpl(mlir::func::FuncOp f, + mlir::RewriterBase &rewriter); +void runFuseOutputTruncfImpl(mlir::func::FuncOp f, + mlir::RewriterBase &rewriter); +void runHoistStaticAllocImpl(mlir::func::FuncOp f, + mlir::RewriterBase &rewriter); +mlir::LogicalResult runBufferizeL1OutputImpl(mlir::func::FuncOp f, + int64_t memorySpace, + llvm::StringRef packedMatmulMarker, + mlir::RewriterBase &rewriter); +mlir::LogicalResult runPostBufferizeCleanupImpl(mlir::func::FuncOp f, + mlir::RewriterBase &rewriter); + +mlir::LogicalResult runBufferizeOutputL2Impl( + mlir::func::FuncOp f, int64_t memorySpace, bool fuseOutputTruncfFirst, + bool doTileL3ToL2Copies, int64_t kL2Tile, llvm::StringRef copyALoopMarker, + llvm::StringRef copyBLoopMarker, mlir::RewriterBase &rewriter); + +mlir::LogicalResult runBufferizeL1InputsImpl(mlir::func::FuncOp f, + int64_t memorySpace, + llvm::StringRef memcpyOp, + llvm::StringRef lhsMarker, + llvm::StringRef rhsMarker, + mlir::RewriterBase &rewriter); + +} // namespace air +} // namespace xilinx + +#endif // AIR_MATMUL_BUFFERIZATION_PASSES_H diff --git a/mlir/include/air/Transform/AIRMatmulCodegen.h b/mlir/include/air/Transform/AIRMatmulCodegen.h new file mode 100644 index 000000000..ada5487fb --- /dev/null +++ b/mlir/include/air/Transform/AIRMatmulCodegen.h @@ -0,0 +1,33 @@ +//===- AIRMatmulCodegen.h ---------------------------------------*- C++ -*-===// +// +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT +// +//===----------------------------------------------------------------------===// +// +// AIRMatmulCodegen: single public matmul codegen pass. Orchestrates the +// internal phases (launch tile, pack, K-tile, core tile, prologue/epilogue, +// bufferization, vectorize) in fixed order. Internal phases are exposed as +// free functions in their respective headers. +// +//===----------------------------------------------------------------------===// + +#ifndef AIR_MATMUL_CODEGEN_H +#define AIR_MATMUL_CODEGEN_H + +#include "air/Transform/PassDetail.h" + +#include "mlir/Pass/Pass.h" +#include + +namespace xilinx { +namespace air { + +std::unique_ptr createAIRMatmulCodegenPass(); +std::unique_ptr +createAIRMatmulCodegenPass(const AIRMatmulCodegenOptions &); + +} // namespace air +} // namespace xilinx + +#endif // AIR_MATMUL_CODEGEN_H diff --git a/mlir/include/air/Transform/AIRMatmulCodegenHelpers.h b/mlir/include/air/Transform/AIRMatmulCodegenHelpers.h new file mode 100644 index 000000000..244a2aef5 --- /dev/null +++ b/mlir/include/air/Transform/AIRMatmulCodegenHelpers.h @@ -0,0 +1,154 @@ +//===- AIRMatmulCodegenHelpers.h --------------------------------*- C++ -*-===// +// +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT +// +//===----------------------------------------------------------------------===// +// +// Free C++ entry points for the matmul codegen transformations originally +// defined as transform.air.* op apply() bodies in AIRLinalgCodegen.cpp. +// Both the existing transform ops and the new air-matmul-* C++ passes call +// these. New helpers are added here as their corresponding apply() body is +// migrated; until migrated, the apply() retains its original logic. +// +//===----------------------------------------------------------------------===// + +#ifndef AIR_MATMUL_CODEGEN_HELPERS_H +#define AIR_MATMUL_CODEGEN_HELPERS_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LogicalResult.h" + +namespace xilinx { +namespace air { + +//===----------------------------------------------------------------------===// +// Pure utilities used by multiple codegen helpers. +//===----------------------------------------------------------------------===// + +/// True if any operation between `firstRead` and `secondRead` (in the same +/// block) writes to `firstRead`'s base memref. +bool hasWritesBetweenReads(::mlir::vector::TransferReadOp firstRead, + ::mlir::vector::TransferReadOp secondRead); + +//===----------------------------------------------------------------------===// +// Free functions backing both transform.air.* ops and air-matmul-* passes. +//===----------------------------------------------------------------------===// + +/// Greedily fold unit-extent dims in linalg ops on `funcOp`, using a +/// memref-aware collapse function (rank-reducing subview for strided memrefs). +::mlir::LogicalResult runFoldUnitExtentDimsOnFunc(::mlir::func::FuncOp funcOp); + +/// Walk all vector.transfer_read in `target` and replace each pair of +/// identical reads with no intervening writes by the first read. Returns +/// the number of eliminations performed. +int runEliminateRedundantVectorTransfers(::mlir::Operation *target, + ::mlir::RewriterBase &rewriter); + +/// Replace vector-typed iter_args of `forOp` with their 1D-flattened form, +/// inserting vector.shape_cast at the loop entry/exit and inside the loop +/// body to convert back to the original shape. Returns the (possibly new) +/// scf.for, or `forOp` unchanged if there were no vector iter_args. +::mlir::FailureOr<::mlir::scf::ForOp> +runFlattenForIterArgs(::mlir::scf::ForOp forOp, ::mlir::RewriterBase &rewriter); + +/// Iteratively hoist matched vector.transfer_read/write pairs whose indices +/// are loop-invariant out of `loopOp` (which must live inside `scopeOp`), +/// threading the accumulator through a new iter_arg. Returns the new loop. +::mlir::FailureOr<::mlir::scf::ForOp> +runHoistLoopInvariantTransfers(::mlir::Operation *scopeOp, + ::mlir::scf::ForOp loopOp, + ::mlir::RewriterBase &rewriter); + +/// Hoist subview/affine.apply chains for vector transfer base pointers out +/// of `forOp` when they are loop-invariant. Returns the (possibly new) +/// scf.for via the rewriter; returns success/failure. +::mlir::LogicalResult +runHoistVectorTransferPointers(::mlir::scf::ForOp forOp, + ::mlir::RewriterBase &rewriter); + +/// Cast vector-typed operands (at `inputIndices`) and/or vector-typed results +/// (at `outputIndices`) of `target` to `targetElementType`, then re-create +/// the op with the casted operand/result types. Empty index lists mean +/// "cast all inputs and outputs". Used for BFP16-mmul emulation: cast +/// vector.contract inputs to bf16 + accumulator/output to f32. +/// Returns success even when the op needs no change; returns failure on +/// validation errors (target has no vector types, etc). +::mlir::LogicalResult runVectorTypeCastOnTarget( + ::mlir::Operation *target, ::mlir::Type targetElementType, + ::llvm::ArrayRef inputIndices, + ::llvm::ArrayRef outputIndices, ::mlir::RewriterBase &rewriter); + +/// Hoist an extension/truncation pair surrounding a loop iter_arg out of +/// `loopOp`: extend the init value before the loop, change the iter_arg to +/// wide type, truncate the result after the loop. `extensionOp` must be +/// arith.extsi/extui/extf and `truncationOp` the matching truncation; both +/// must live inside `loopOp`. Returns the new scf.for on success. +::mlir::FailureOr<::mlir::scf::ForOp> +runHoistCastPair(::mlir::Operation *extensionOp, + ::mlir::Operation *truncationOp, ::mlir::scf::ForOp loopOp, + ::mlir::RewriterBase &rewriter); + +//===----------------------------------------------------------------------===// +// Bufferization & fusion utilities used by the air-matmul-codegen +// orchestrator phases. +//===----------------------------------------------------------------------===// + +/// Apply OptimizeCopyOpPattern to remove copies whose source is uninitialized +/// (or only filled), replacing them with linalg.fill. Operates greedily on +/// `funcOp`. +::mlir::LogicalResult runRemoveUninitializedCopy(::mlir::func::FuncOp funcOp); + +/// Apply EliminateIntermediateMemrefPattern to collapse cascade memcpy +/// sequences (intermediate memref alloc + double copy) on `target`. +::mlir::LogicalResult runEliminateCascadeMemcpy(::mlir::Operation *target); + +/// Apply ConvertMemrefCopyToLinalgCopyPattern: rewrite memref.copy to +/// linalg.copy on `target`. Required before tile-using-for of L3->L2 copies +/// (TilingInterface lives on linalg.copy, not memref.copy). +::mlir::LogicalResult +runConvertMemrefCopyToLinalgCopy(::mlir::Operation *target); + +/// Tile-and-fuse `producerOp` (a LinalgOp with one DPS init) into the first +/// memref.subview use found inside `containingOp` (typically an scf.for/forall +/// body). Returns the tiled fused op on success, nullptr on failure. +::mlir::Operation *runFuseIntoContainingMemref(::mlir::Operation *producerOp, + ::mlir::Operation *containingOp, + ::mlir::RewriterBase &rewriter); + +/// True iff `linalgOp`'s body contains exactly one non-terminator op and that +/// op is arith.truncf. Used to identify "truncf-only" linalg ops eligible for +/// fusion into their producer. +bool containsOnlyTruncfOp(::mlir::linalg::LinalgOp linalgOp); + +/// True iff `producerOp` produces a single result that is consumed by +/// `truncfOp` as one of its DPS inputs. +bool producesResultForOp(::mlir::linalg::LinalgOp producerOp, + ::mlir::linalg::LinalgOp truncfOp); + +/// Fuse a truncf-only linalg op into its producer. The fused op accumulates +/// in the producer's wide type but yields the truncated type. If inputs are +/// 2D+ (matmul-shaped), replace the fused generic with linalg.matmul of the +/// truncated output type and return that matmul; otherwise return the fused +/// generic. Both `producerOp` and `truncfOp` are erased. +::mlir::FailureOr<::mlir::Operation *> +runFuseTruncfLinalg(::mlir::linalg::LinalgOp producerOp, + ::mlir::linalg::LinalgOp truncfOp, + ::mlir::RewriterBase &rewriter); + +/// Fold affine.apply ops into `forOp`'s lower/upper bounds via +/// xilinx::air::foldAffineApplyIntoLoopBounds. Returns the (possibly new) +/// scf.for, or `forOp` unchanged if the fold did not apply. AIR-only. +::mlir::scf::ForOp runNormalizeForBounds(::mlir::scf::ForOp forOp, + ::mlir::RewriterBase &rewriter); + +} // namespace air +} // namespace xilinx + +#endif // AIR_MATMUL_CODEGEN_HELPERS_H diff --git a/mlir/include/air/Transform/AIRMatmulPackAndTranspose.h b/mlir/include/air/Transform/AIRMatmulPackAndTranspose.h new file mode 100644 index 000000000..217929eaf --- /dev/null +++ b/mlir/include/air/Transform/AIRMatmulPackAndTranspose.h @@ -0,0 +1,31 @@ +//===- AIRMatmulPackAndTranspose.h ------------------------------*- C++ -*-===// +// +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT +// +//===----------------------------------------------------------------------===// + +#ifndef AIR_MATMUL_PACK_AND_TRANSPOSE_H +#define AIR_MATMUL_PACK_AND_TRANSPOSE_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" + +namespace xilinx { +namespace air { + +mlir::LogicalResult runPackAndTransposeImpl( + mlir::func::FuncOp f, llvm::ArrayRef packSizes, + llvm::ArrayRef lhsOuter, llvm::ArrayRef lhsInner, + llvm::ArrayRef rhsOuter, llvm::ArrayRef rhsInner, + llvm::ArrayRef accOuter, llvm::ArrayRef accInner, + llvm::StringRef packedMatmulMarker, bool doBufferizeL1Output, + int64_t bufferizeL1OutputMemorySpace, mlir::RewriterBase &rewriter); + +} // namespace air +} // namespace xilinx + +#endif // AIR_MATMUL_PACK_AND_TRANSPOSE_H diff --git a/mlir/include/air/Transform/AIRMatmulTileL3ToL2Copies.h b/mlir/include/air/Transform/AIRMatmulTileL3ToL2Copies.h new file mode 100644 index 000000000..a7d135ca8 --- /dev/null +++ b/mlir/include/air/Transform/AIRMatmulTileL3ToL2Copies.h @@ -0,0 +1,32 @@ +//===- AIRMatmulTileL3ToL2Copies.h ------------------------------*- C++ -*-===// +// +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT +// +//===----------------------------------------------------------------------===// +// +// Free-function body for the former `air-matmul-tile-l3-to-l2-copies` pass. +// Now invoked from `air-matmul-bufferize-output-l2` when its +// `do-tile-l3-to-l2-copies` option is set. +// +//===----------------------------------------------------------------------===// + +#ifndef AIR_MATMUL_TILE_L3_TO_L2_COPIES_H +#define AIR_MATMUL_TILE_L3_TO_L2_COPIES_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/StringRef.h" + +namespace xilinx { +namespace air { + +mlir::LogicalResult +runTileL3ToL2CopiesImpl(mlir::func::FuncOp func, int64_t kL2Tile, + llvm::StringRef copyAMarker = "copy_a_loop", + llvm::StringRef copyBMarker = "copy_b_loop"); + +} // namespace air +} // namespace xilinx + +#endif // AIR_MATMUL_TILE_L3_TO_L2_COPIES_H diff --git a/mlir/include/air/Transform/AIRMatmulTilePasses.h b/mlir/include/air/Transform/AIRMatmulTilePasses.h new file mode 100644 index 000000000..3bb75d590 --- /dev/null +++ b/mlir/include/air/Transform/AIRMatmulTilePasses.h @@ -0,0 +1,57 @@ +//===- AIRMatmulTilePasses.h ------------------------------------*- C++ -*-===// +// +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT +// +//===----------------------------------------------------------------------===// +// +// Free-function bodies invoked by the air-matmul-codegen orchestrator: +// launch-tile, tile-k-and-fuse-packs, tile-cores, and prologue/epilogue +// tiling. Each drives a discrete tiling step on the packed matmul (and, +// where applicable, fuses the LHS/RHS pack producers into the new loop). +// +//===----------------------------------------------------------------------===// + +#ifndef AIR_MATMUL_TILE_PASSES_H +#define AIR_MATMUL_TILE_PASSES_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" + +namespace xilinx { +namespace air { + +mlir::LogicalResult +runTileLaunchTileImpl(mlir::func::FuncOp f, llvm::ArrayRef tileSizes, + llvm::StringRef launchTileForallMarker, + mlir::RewriterBase &rewriter); + +mlir::LogicalResult runTileKAndFusePacksImpl( + mlir::func::FuncOp f, int64_t kTileFactor, int64_t kIterIndex, + llvm::StringRef packedMatmulMarker, llvm::StringRef kReductionLoopMarker, + llvm::StringRef lhsPackMarker, llvm::StringRef rhsPackMarker, + llvm::StringRef lhsL2PackMarker, llvm::StringRef rhsL2PackMarker, + mlir::RewriterBase &rewriter); + +mlir::LogicalResult runTileCoresImpl( + mlir::func::FuncOp f, llvm::ArrayRef tileSizes, + llvm::StringRef packedMatmulMarker, llvm::StringRef lhsPackInKMarker, + llvm::StringRef rhsPackInKMarker, llvm::StringRef computeForallMarker, + llvm::StringRef matmulComputeMarker, llvm::StringRef lhsL1PackMarker, + llvm::StringRef rhsL1PackMarker, mlir::RewriterBase &rewriter); + +mlir::LogicalResult runPrologueEpilogueImpl( + mlir::func::FuncOp f, llvm::ArrayRef prologueTileSizes, + llvm::ArrayRef epilogueTileSizes, + llvm::ArrayRef fillIteratorInterchange, + llvm::StringRef initFillMarker, llvm::StringRef prologueForallMarker, + llvm::StringRef epilogueForallMarker, bool hoistStaticAllocFirst, + mlir::RewriterBase &rewriter); + +} // namespace air +} // namespace xilinx + +#endif // AIR_MATMUL_TILE_PASSES_H diff --git a/mlir/include/air/Transform/AIRMatmulVectorizePasses.h b/mlir/include/air/Transform/AIRMatmulVectorizePasses.h new file mode 100644 index 000000000..123542f01 --- /dev/null +++ b/mlir/include/air/Transform/AIRMatmulVectorizePasses.h @@ -0,0 +1,47 @@ +//===- AIRMatmulVectorizePasses.h -------------------------------*- C++ -*-===// +// +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT +// +//===----------------------------------------------------------------------===// +// +// Vectorization-prep free functions invoked by the air-matmul-codegen +// orchestrator: tile-for-vectorize and the vec-prep composite (eliminate- +// redundant-transfers, vector-cast-for-emulation, hoist-loop-invariant, +// flatten-for-iter-args, hoist-vector-transfer-pointers, hoist-cast-pairs). +// +//===----------------------------------------------------------------------===// + +#ifndef AIR_MATMUL_VECTORIZE_PASSES_H +#define AIR_MATMUL_VECTORIZE_PASSES_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include + +namespace xilinx { +namespace air { + +mlir::LogicalResult runTileForVectorizeImpl( + mlir::func::FuncOp f, llvm::ArrayRef matmulTileSizes, + llvm::ArrayRef matmulUnrollTileSizes, int64_t matmulUnrollFactor, + llvm::ArrayRef fillTileSizes, bool doPostBufferizeCleanupFirst, + mlir::RewriterBase &rewriter); + +mlir::LogicalResult runCodegenVecPrepImpl( + mlir::func::FuncOp f, llvm::StringRef cast1TargetElementType, + llvm::ArrayRef cast1InputIndices, + llvm::ArrayRef cast1OutputIndices, + llvm::StringRef cast2TargetElementType, + llvm::ArrayRef cast2InputIndices, + llvm::ArrayRef cast2OutputIndices, bool doHoistCastPairs, + int64_t hoistCastPairsMaxIterations, mlir::RewriterBase &rewriter); + +} // namespace air +} // namespace xilinx + +#endif // AIR_MATMUL_VECTORIZE_PASSES_H diff --git a/mlir/include/air/Transform/PassDetail.h b/mlir/include/air/Transform/PassDetail.h index 8cabe1b29..cbbcb02c6 100644 --- a/mlir/include/air/Transform/PassDetail.h +++ b/mlir/include/air/Transform/PassDetail.h @@ -50,6 +50,8 @@ namespace air { #define GEN_PASS_DEF_AIRLABELSCFFORLOOPINAIRSEGMENTPATTERN #define GEN_PASS_DEF_AIRSPECIALIZECHANNELWRAPANDSTRIDEPATTERN #define GEN_PASS_DEF_AIRLINALGCODEGEN +#define GEN_PASS_DEF_AIRFOLDUNITEXTENTDIMS +#define GEN_PASS_DEF_AIRMATMULCODEGEN #define GEN_PASS_DEF_AIRLINALGNAMEPASS #define GEN_PASS_DEF_AIRLINALGOPSTATS #define GEN_PASS_DEF_AIRLOOPMERGINGPASS diff --git a/mlir/include/air/Transform/Passes.h b/mlir/include/air/Transform/Passes.h index de8aab84a..5f1f62492 100644 --- a/mlir/include/air/Transform/Passes.h +++ b/mlir/include/air/Transform/Passes.h @@ -24,6 +24,12 @@ #include "air/Transform/AIRLoopMergingPass.h" #include "air/Transform/AIRLoopPermutationPass.h" #include "air/Transform/AIRLowerLinalgTensors.h" +#include "air/Transform/AIRMatmulBufferizationPasses.h" +#include "air/Transform/AIRMatmulCodegen.h" +#include "air/Transform/AIRMatmulPackAndTranspose.h" +#include "air/Transform/AIRMatmulTileL3ToL2Copies.h" +#include "air/Transform/AIRMatmulTilePasses.h" +#include "air/Transform/AIRMatmulVectorizePasses.h" #include "air/Transform/AIRMiscPasses.h" #include "air/Transform/AIRRegularizeLoopPass.h" #include "air/Transform/AIRSplitLaunchForPadding.h" diff --git a/mlir/include/air/Transform/Passes.td b/mlir/include/air/Transform/Passes.td index 5743c8a13..cebadce61 100644 --- a/mlir/include/air/Transform/Passes.td +++ b/mlir/include/air/Transform/Passes.td @@ -1107,6 +1107,215 @@ def AIRSplitLaunchForPadding: Pass<"air-split-launch-for-padding", "ModuleOp"> { ]; } +def AIRMatmulCodegen : Pass<"air-matmul-codegen", "ModuleOp"> { + let summary = "Single public matmul codegen pass. Orchestrates internal " + "phases (launch tile, packs, K-tile, core tile, " + "prologue/epilogue, bufferize-to-alloc, one-shot-bufferize, " + "tile-for-vectorize, vec-prep) in fixed order. Each phase is " + "skipped when its config is empty / zero / disabled."; + let constructor = "xilinx::air::createAIRMatmulCodegenPass()"; + let description = [{ + Orchestrates the matmul codegen pipeline as a single pass. Internal + phase order (each gated by its config; canonicalize/cse runs between + most phases): + + A. tile-launch-tile (launch-tile) + B. pack-and-transpose (l2-pack-sizes + l2-*-perm) + C. bufferize-output-l2 (bufferize-output-l2 + optional pre-steps) + D. pack-and-transpose (l1-pack-sizes + l1-*-perm; L1-output bufferize) + E. tile-k-and-fuse-packs (outer-k-tile-factor) + F. bufferize-l1-inputs into L2 (auto when D ran) + H. tile-cores (core-tile) + I. tile-k-and-fuse-packs (inner-k-tile-factor) + J. bufferize-l1-inputs into L1 (auto when H ran) + K. prologue-epilogue (prologue-tile / epilogue-tile) + L. one-shot-bufferize (one-shot-bufferize) + M. tile-for-vectorize (matmul-vec-tile) + N. vec-prep composite + + Skipping a phase is the natural way to compose subsets: tests using + only the vectorize stages leave A--K empty and one-shot-bufferize=false. + Phase L is gated by the one-shot-bufferize option (default true). Phase N + (vec-prep composite) always runs but its individual steps walk for + vector ops, so it becomes a no-op on pre-vectorize IR; for tests using + only the tile/pack stages it is therefore a cheap no-op. + }]; + let options = [ + // ---- Phase A: launch tile ---- + ListOption<"clLaunchTile", "launch-tile", "int64_t", + "Tile sizes for the outer launch-tile scf.forall. Skipped if " + "empty.", + "llvm::cl::ZeroOrMore">, + + // ---- Phase B: L2 pack ---- + ListOption<"clL2PackSizes", "l2-pack-sizes", "int64_t", + "Per-iterator pack sizes for the L2 pack. Skipped if empty.", + "llvm::cl::ZeroOrMore">, + ListOption<"clL2LhsOuterPerm", "l2-lhs-outer-perm", "int64_t", + "L2 LHS outer-dim perm.", "llvm::cl::ZeroOrMore">, + ListOption<"clL2LhsInnerPerm", "l2-lhs-inner-perm", "int64_t", + "L2 LHS inner-dim perm.", "llvm::cl::ZeroOrMore">, + ListOption<"clL2RhsOuterPerm", "l2-rhs-outer-perm", "int64_t", + "L2 RHS outer-dim perm.", "llvm::cl::ZeroOrMore">, + ListOption<"clL2RhsInnerPerm", "l2-rhs-inner-perm", "int64_t", + "L2 RHS inner-dim perm.", "llvm::cl::ZeroOrMore">, + ListOption<"clL2AccOuterPerm", "l2-acc-outer-perm", "int64_t", + "L2 accumulator outer-dim perm.", "llvm::cl::ZeroOrMore">, + ListOption<"clL2AccInnerPerm", "l2-acc-inner-perm", "int64_t", + "L2 accumulator inner-dim perm.", "llvm::cl::ZeroOrMore">, + + // ---- Phase C: bufferize output L2 alloc ---- + Option<"clBufferizeOutputL2", "bufferize-output-l2", "bool", + /*default=*/"false", + "Bufferize the matmul accumulator init (linalg.fill) into an L2 " + "allocation.">, + Option<"clBufferizeOutputL2MemorySpace", + "bufferize-output-l2-memory-space", "int64_t", /*default=*/"1", + "Memory space for the L2 accumulator allocation.">, + Option< + "clFuseOutputTruncfFirst", "fuse-output-truncf-first", "bool", + /*default=*/"false", + "Pre-step: fuse a single-truncf linalg.generic consumer of the " + "matmul into the matmul before bufferizing. Used by bf16-out flows.">, + Option<"clTileL3ToL2Copies", "tile-l3-to-l2-copies", "bool", + /*default=*/"false", + "Pre-step: convert memref.copy L3->L2 stagings to linalg.copy and " + "tile each by k-l2-tile. Used by Triton-style flows.">, + Option<"clKL2Tile", "k-l2-tile", "int64_t", /*default=*/"16", + "K-tile size for L3->L2 copies (only when " + "tile-l3-to-l2-copies=true).">, + + // ---- Phase D: L1 pack ---- + ListOption<"clL1PackSizes", "l1-pack-sizes", "int64_t", + "Per-iterator pack sizes for the L1 pack. Skipped if empty. " + "When set, the L1 pack output is also bufferized to L1.", + "llvm::cl::ZeroOrMore">, + ListOption<"clL1LhsOuterPerm", "l1-lhs-outer-perm", "int64_t", + "L1 LHS outer-dim perm.", "llvm::cl::ZeroOrMore">, + ListOption<"clL1LhsInnerPerm", "l1-lhs-inner-perm", "int64_t", + "L1 LHS inner-dim perm.", "llvm::cl::ZeroOrMore">, + ListOption<"clL1RhsOuterPerm", "l1-rhs-outer-perm", "int64_t", + "L1 RHS outer-dim perm.", "llvm::cl::ZeroOrMore">, + ListOption<"clL1RhsInnerPerm", "l1-rhs-inner-perm", "int64_t", + "L1 RHS inner-dim perm.", "llvm::cl::ZeroOrMore">, + ListOption<"clL1AccOuterPerm", "l1-acc-outer-perm", "int64_t", + "L1 accumulator outer-dim perm.", "llvm::cl::ZeroOrMore">, + ListOption<"clL1AccInnerPerm", "l1-acc-inner-perm", "int64_t", + "L1 accumulator inner-dim perm.", "llvm::cl::ZeroOrMore">, + Option<"clL1OutputMemorySpace", "l1-output-memory-space", "int64_t", + /*default=*/"2", + "Memory space for the bufferized L1 pack output.">, + Option< + "clBufferizeLastPackOutput", "bufferize-last-pack-output", "bool", + /*default=*/"true", + "Bufferize the LAST pack's output (L1 pack if l1-pack-sizes is set, " + "otherwise the L2 pack) into L1 memory. Set false to leave the " + "pack output as a tensor (e.g. for inspecting raw pack semantics).">, + + // ---- Phase E: outer K-tile ---- + Option<"clOuterKTileFactor", "outer-k-tile-factor", "int64_t", + /*default=*/"0", + "K-tile size for the outer K reduction loop. Skipped if 0.">, + Option<"clOuterKIterIndex", "outer-k-iter-index", "int64_t", + /*default=*/"2", + "K iterator index for the outer K-tile (default 2 = standard " + "post-pack [m,n,k]).">, + + // ---- Phase H: tile cores ---- + ListOption<"clCoreTile", "core-tile", "int64_t", + "Per-iterator tile sizes for the per-core scf.forall. Skipped " + "if empty.", + "llvm::cl::ZeroOrMore">, + + // ---- Phase I: inner K-tile ---- + Option<"clInnerKTileFactor", "inner-k-tile-factor", "int64_t", + /*default=*/"0", + "K-tile size for the inner K reduction loop. Skipped if 0.">, + Option< + "clInnerKIterIndex", "inner-k-iter-index", "int64_t", + /*default=*/"5", + "K iterator index for the inner K-tile (default 5 = two-pack-level " + "inner K position).">, + + // ---- Phase K: prologue/epilogue ---- + ListOption<"clPrologueTile", "prologue-tile", "int64_t", + "Tile sizes for the prologue (fill) forall.", + "llvm::cl::ZeroOrMore">, + ListOption<"clEpilogueTile", "epilogue-tile", "int64_t", + "Tile sizes for the epilogue (unpack) forall.", + "llvm::cl::ZeroOrMore">, + ListOption<"clFillIterPerm", "fill-iter-perm", "int64_t", + "Iterator-permutation vector applied to the generalized fill " + "before tiling. Empty disables interchange.", + "llvm::cl::ZeroOrMore">, + Option<"clHoistStaticAllocFirst", "hoist-static-alloc-first", "bool", + /*default=*/"false", + "Pre-step: hoist statically-bound memref.alloc ops out of nested " + "loops to function entry. Used by the two-pack-level flow.">, + + // ---- Phase L: one-shot bufferize ---- + Option<"clOneShotBufferize", "one-shot-bufferize", "bool", + /*default=*/"false", + "Run upstream one-shot-bufferize (function-boundary, " + "identity-layout) after the tile/pack stages and before the " + "vectorize stages.">, + + // ---- Phase M: tile for vectorize ---- + ListOption<"clMatmulVecTile", "matmul-vec-tile", "int64_t", + "First-level tile sizes for the packed matmul body. Skipped " + "if empty.", + "llvm::cl::ZeroOrMore">, + ListOption<"clMatmulUnrollVecTile", "matmul-unroll-vec-tile", "int64_t", + "Second-level tile sizes (the two innermost loops are " + "unrolled).", + "llvm::cl::ZeroOrMore">, + Option<"clMatmulUnrollFactor", "matmul-unroll-factor", "uint64_t", + /*default=*/"2", + "Unroll factor applied to the two innermost loops.">, + ListOption<"clFillVecTile", "fill-vec-tile", "int64_t", + "Tile sizes for linalg.fill in the vectorize stage.", + "llvm::cl::ZeroOrMore">, + Option< + "clPostBufferizeCleanupFirst", "post-bufferize-cleanup-first", "bool", + /*default=*/"false", + "Pre-step: run post-bufferize cleanup (remove uninitialized " + "copies, eliminate cascade memcpys, sibling-fuse pingpong loops).">, + + // ---- Phase N: vec-prep composite (always runs; no-op on + // pre-vectorize IR when called between tiling phases) ---- + Option<"clVecPrepCast1TargetElementType", + "vec-prep-cast1-target-element-type", "std::string", + /*default=*/"\"\"", + "vec-prep: first vector-cast target element type ('' = skip).">, + ListOption<"clVecPrepCast1InputIndices", "vec-prep-cast1-input-indices", + "int64_t", + "vec-prep: first vector-cast input operand indices.", + "llvm::cl::ZeroOrMore">, + ListOption<"clVecPrepCast1OutputIndices", "vec-prep-cast1-output-indices", + "int64_t", + "vec-prep: first vector-cast output operand indices.", + "llvm::cl::ZeroOrMore">, + Option<"clVecPrepCast2TargetElementType", + "vec-prep-cast2-target-element-type", "std::string", + /*default=*/"\"\"", + "vec-prep: second vector-cast target element type ('' = skip).">, + ListOption<"clVecPrepCast2InputIndices", "vec-prep-cast2-input-indices", + "int64_t", + "vec-prep: second vector-cast input operand indices.", + "llvm::cl::ZeroOrMore">, + ListOption<"clVecPrepCast2OutputIndices", "vec-prep-cast2-output-indices", + "int64_t", + "vec-prep: second vector-cast output operand indices.", + "llvm::cl::ZeroOrMore">, + Option<"clVecPrepHoistCastPairs", "vec-prep-hoist-cast-pairs", "bool", + /*default=*/"false", + "vec-prep: iteratively hoist matched ext/trunc pairs.">, + Option<"clVecPrepHoistCastPairsMaxIterations", + "vec-prep-hoist-cast-pairs-max-iterations", "int64_t", + /*default=*/"32", + "vec-prep: fixed-point cap when vec-prep-hoist-cast-pairs=true.">]; +} + def AIRLoopFusion: Pass<"air-loop-fusion", "func::FuncOp"> { let summary = "Hoist dma ops into perfectly nested loop"; let constructor = "xilinx::air::createAIRLoopFusion()"; diff --git a/mlir/include/air/Util/MatmulCodegenConfig.h b/mlir/include/air/Util/MatmulCodegenConfig.h new file mode 100644 index 000000000..b64d1f1d3 --- /dev/null +++ b/mlir/include/air/Util/MatmulCodegenConfig.h @@ -0,0 +1,95 @@ +//===- MatmulCodegenConfig.h ------------------------------------*- C++ -*-===// +// +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT +// +//===----------------------------------------------------------------------===// +// +// Carrier attribute + reader/writer helpers for the matmul codegen pipeline. +// External producers (autotuners, future heuristic passes) write the +// attribute on each linalg.matmul (or marker-attributed LinalgOp). The +// air-matmul-codegen orchestrator currently does NOT read this attribute +// (per-phase options are passed explicitly by the caller); this header +// remains so the schema and helpers are available to the future heuristic. +// The attribute is a `DictionaryAttr` named "air.matmul_codegen_config" +// with the following keys (any field may be missing): +// +// tile_l3_l2_k : i64 +// pack_sizes : ArrayAttr (length 3) +// lhs_outer_perm : ArrayAttr (length 2; e.g. [1,0]) +// lhs_inner_perm : ArrayAttr +// rhs_outer_perm : ArrayAttr +// rhs_inner_perm : ArrayAttr +// acc_outer_perm : ArrayAttr +// acc_inner_perm : ArrayAttr +// tile_k_factor : i64 +// tile_cores : ArrayAttr +// prologue_tile : ArrayAttr +// epilogue_tile : ArrayAttr +// fill_iter_perm : ArrayAttr +// vector_tile : ArrayAttr (length 6 for packed matmul) +// vector_unroll_tile: ArrayAttr +// vector_unroll_factor : i64 +// fill_vector_tile : ArrayAttr +// bfp16_emulation : bool +// fuse_output_truncf : bool +// bf16_output_hoist_pairs : bool +// three_herd_prologue_epilogue: bool +// +//===----------------------------------------------------------------------===// + +#ifndef AIR_UTIL_MATMUL_CODEGEN_CONFIG_H +#define AIR_UTIL_MATMUL_CODEGEN_CONFIG_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Operation.h" +#include "llvm/ADT/SmallVector.h" + +namespace xilinx { +namespace air { + +/// Discardable attribute name on the linalg.matmul (or its packed marker +/// successor) carrying the codegen config dictionary. +inline llvm::StringRef getMatmulCodegenConfigAttrName() { + return "air.matmul_codegen_config"; +} + +/// Find the codegen-config DictionaryAttr in `funcOp`. Looks for the first op +/// in the function carrying `getMatmulCodegenConfigAttrName()`. Returns the +/// dict (possibly empty) on success, std::nullopt if no config is attached. +std::optional<::mlir::DictionaryAttr> +findMatmulCodegenConfig(::mlir::func::FuncOp funcOp); + +/// Helper: extract an `ArrayAttr` field from `cfg` as +/// `SmallVector`. Returns an empty vector if the field is missing or +/// the wrong type. +::llvm::SmallVector getI64Array(::mlir::DictionaryAttr cfg, + ::llvm::StringRef key); + +/// Helper: extract an i64 field from `cfg`. Returns `defaultVal` if missing. +int64_t getI64(::mlir::DictionaryAttr cfg, ::llvm::StringRef key, + int64_t defaultVal); + +/// Helper: extract a bool field from `cfg`. Returns `defaultVal` if missing. +bool getBool(::mlir::DictionaryAttr cfg, ::llvm::StringRef key, + bool defaultVal); + +/// Build (and write) a DictionaryAttr config onto the first linalg.matmul (or +/// op marked `markerName`) in `funcOp`. Existing entries in `dict` overwrite +/// any prior config. Returns true if an op was found and the attribute was +/// written; false otherwise. +bool writeMatmulCodegenConfig(::mlir::func::FuncOp funcOp, + ::mlir::DictionaryAttr dict, + ::llvm::StringRef markerName = ""); + +/// Build a DictionaryAttr from a list of (name, attr) pairs, dropping any +/// entries with null attrs. Convenience wrapper around DictionaryAttr::get. +::mlir::DictionaryAttr +buildMatmulCodegenConfig(::mlir::MLIRContext *ctx, + ::llvm::ArrayRef<::mlir::NamedAttribute> entries); + +} // namespace air +} // namespace xilinx + +#endif // AIR_UTIL_MATMUL_CODEGEN_CONFIG_H diff --git a/mlir/include/air/Util/Util.h b/mlir/include/air/Util/Util.h index e737e99cb..9f59a6abf 100644 --- a/mlir/include/air/Util/Util.h +++ b/mlir/include/air/Util/Util.h @@ -374,6 +374,27 @@ Operation *cloneOpAndOperands( bool opOrAncestorIsDominantOver(Operation *a, Operation *b); +// Walk `root` for the first op (any kind) carrying `attrName` as either an +// inherent or a discardable attribute (uses Operation::hasAttr, which checks +// both). Returns nullptr if no match. +mlir::Operation *findOpWithAttr(mlir::Operation *root, + llvm::StringRef attrName); + +// Walk `root` for the first op of type `OpTy` carrying `attrName` as either +// an inherent or a discardable attribute. Returns null OpTy if no match. +template +OpTy findOpOfTypeWithAttr(mlir::Operation *root, llvm::StringRef attrName) { + OpTy found; + root->walk([&](OpTy op) { + if (op->hasAttr(attrName)) { + found = op; + return mlir::WalkResult::interrupt(); + } + return mlir::WalkResult::advance(); + }); + return found; +} + } // namespace air } // namespace xilinx diff --git a/mlir/lib/Transform/AIRLinalgBufferize.cpp b/mlir/lib/Transform/AIRLinalgBufferize.cpp index 935c5ae66..d6a5475fa 100644 --- a/mlir/lib/Transform/AIRLinalgBufferize.cpp +++ b/mlir/lib/Transform/AIRLinalgBufferize.cpp @@ -164,8 +164,7 @@ static bool isUseReplaceableWithSubview(OpOperand &use) { memref::SubViewOp>(user); } -template -std::optional hoistOneStaticallyBoundAllocation( +static std::optional hoistOneStaticallyBoundAllocation( mlir::FunctionOpInterface funcOp, OpBuilder &builder, Location loc, MemRefType allocLikeType, ValueRange dynamicSizes, std::optional alignment, @@ -182,14 +181,9 @@ std::optional hoistOneStaticallyBoundAllocation( OpBuilder::InsertionGuard g(builder); builder.setInsertionPointToStart(&funcOp.getFunctionBody().front()); Value allocation = - AllocLikeOpType::create(builder, loc, allocLikeType, alignmentAttr); - // For memref.alloc, also insert a dealloc in the entry block terminator - // block to preserve semantics (leaks avoided). - if (std::is_same::value) { - builder.setInsertionPoint( - funcOp.getFunctionBody().front().getTerminator()); - memref::DeallocOp::create(builder, loc, allocation); - } + memref::AllocOp::create(builder, loc, allocLikeType, alignmentAttr); + builder.setInsertionPoint(funcOp.getFunctionBody().front().getTerminator()); + memref::DeallocOp::create(builder, loc, allocation); return allocation; } @@ -225,7 +219,7 @@ std::optional hoistOneStaticallyBoundAllocation( dispatchIndexOpFoldResults(allocSizes, dynamicSizes, staticShape); auto allocationType = allocLikeType.clone(staticShape); - allocation = AllocLikeOpType::create(builder, loc, allocationType, + allocation = memref::AllocOp::create(builder, loc, allocationType, dynamicSizes, alignmentAttr); } @@ -246,54 +240,52 @@ std::optional hoistOneStaticallyBoundAllocation( } // As above, insert a dealloc at function end. - if (std::is_same::value) { - builder.setInsertionPoint(funcOp.getFunctionBody().front().getTerminator()); - memref::DeallocOp::create(builder, loc, allocation); - } + builder.setInsertionPoint(funcOp.getFunctionBody().front().getTerminator()); + memref::DeallocOp::create(builder, loc, allocation); return subviewOp; } -template -std::optional hoistOneStaticallyBoundAllocation( +static std::optional hoistOneStaticallyBoundAllocation( mlir::FunctionOpInterface funcOp, OpBuilder &builder, - AllocLikeOpType allocLikeOp, + memref::AllocOp allocLikeOp, std::optional vscaleRange) { // Convenience overload: set insertion point to the original alloc-like op // and forward its properties to the main hoisting routine. OpBuilder::InsertionGuard guard(builder); builder.setInsertionPoint(allocLikeOp); - return hoistOneStaticallyBoundAllocation( + return hoistOneStaticallyBoundAllocation( funcOp, builder, allocLikeOp.getLoc(), allocLikeOp.getType(), allocLikeOp.getDynamicSizes(), allocLikeOp.getAlignment(), vscaleRange); } -template -void hoistStaticallyBoundAllocationsInFunc( - RewriterBase &rewriter, mlir::FunctionOpInterface funcOp, - std::optional vscaleRange = std::nullopt) { - SmallVector allocLikeOps; +namespace xilinx { +namespace air { + +void hoistStaticAllocsInFunc(RewriterBase &rewriter, + mlir::FunctionOpInterface funcOp) { + SmallVector allocOps; - // Collect candidate alloc-like ops that are not already in the entry block - // and whose uses are safe to rewrite (or have no dynamic sizes). - funcOp.walk([&](AllocLikeOpType allocLikeOp) { - if (allocLikeOp->getBlock() == &funcOp.getFunctionBody().front()) + // Collect candidate allocs that are not already in the entry block and whose + // uses are safe to rewrite (or have no dynamic sizes). + funcOp.walk([&](memref::AllocOp allocOp) { + if (allocOp->getBlock() == &funcOp.getFunctionBody().front()) return; - if (allocLikeOp.getDynamicSizes().empty()) { - allocLikeOps.push_back(allocLikeOp); + if (allocOp.getDynamicSizes().empty()) { + allocOps.push_back(allocOp); return; } // All uses must tolerate replacement by a subview. - if (llvm::all_of(allocLikeOp->getUses(), [](OpOperand &use) { + if (llvm::all_of(allocOp->getUses(), [](OpOperand &use) { return isUseReplaceableWithSubview(use); })) { - allocLikeOps.push_back(allocLikeOp); + allocOps.push_back(allocOp); return; } }); // Hoist each candidate and replace all uses with the hoisted value. - for (auto allocLikeOp : allocLikeOps) { + for (auto allocLikeOp : allocOps) { // Track and remove any deallocs tied to the original allocation; the new // hoisted allocation installs its own dealloc in the entry block. SmallVector deallocOps; @@ -311,7 +303,7 @@ void hoistStaticallyBoundAllocationsInFunc( llvm::dbgs() << " num Uses : " << numUses; }); std::optional replacement = hoistOneStaticallyBoundAllocation( - funcOp, rewriter, allocLikeOp, vscaleRange); + funcOp, rewriter, allocLikeOp, /*vscaleRange=*/std::nullopt); if (!replacement) continue; LLVM_DEBUG({ @@ -326,14 +318,14 @@ void hoistStaticallyBoundAllocationsInFunc( } } +} // namespace air +} // namespace xilinx + DiagnosedSilenceableFailure transform::AIRHoistStaticAllocOp::applyToOne( transform::TransformRewriter &rewriter, mlir::FunctionOpInterface target, transform::ApplyToEachResultList &results, transform::TransformState &state) { - // Apply the hoisting pass to all memref.alloc ops in the target function. - // If more alloc-like ops should be supported, template parameterization - // allows calling this routine for those as well. - hoistStaticallyBoundAllocationsInFunc(rewriter, target); + xilinx::air::hoistStaticAllocsInFunc(rewriter, target); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Transform/AIRLinalgCodegen.cpp b/mlir/lib/Transform/AIRLinalgCodegen.cpp index f51fb116b..301f8440c 100644 --- a/mlir/lib/Transform/AIRLinalgCodegen.cpp +++ b/mlir/lib/Transform/AIRLinalgCodegen.cpp @@ -9,6 +9,7 @@ #include "air/Transform/AIRLinalgCodegen.h" #include "air/Dialect/AIR/AIRDialect.h" #include "air/Dialect/AIR/AIRTransformOps.h" +#include "air/Transform/AIRMatmulCodegenHelpers.h" #if AIR_ENABLE_AIE #include "air/Transform/AIRDependencyScheduleOpt.h" #endif @@ -2510,41 +2511,34 @@ DiagnosedSilenceableFailure transform::FuseIntoContainingMemrefOp::apply( SmallVector fusedOps; SmallVector producerOps = llvm::to_vector(state.getPayloadOps(getProducerOp())); - // If nothing to fuse, propagate success. if (producerOps.empty()) { results.set(llvm::cast(getFusedOp()), SmallVector{}); return DiagnosedSilenceableFailure::success(); } - if (producerOps.size() != 1) { + if (producerOps.size() != 1) return emitDefiniteFailure() << "requires exactly one producer_op handle (got " << producerOps.size() << ")"; - } Operation *producerOp = producerOps.front(); SmallVector containingOps = llvm::to_vector(state.getPayloadOps(getContainingOp())); - if (containingOps.size() != 1) { + if (containingOps.size() != 1) return emitDefiniteFailure() << "requires exactly one containing_op handle (got " << containingOps.size() << ")"; - } Operation *containingOp = containingOps.front(); - linalg::LinalgOp producerLinalgOp = - dyn_cast_if_present(producerOp); - if (!producerLinalgOp) { + auto producerLinalgOp = dyn_cast_if_present(producerOp); + if (!producerLinalgOp) return emitDefiniteFailure() << "requires producer_op to be LinalgOp"; - } - if (producerLinalgOp.getNumDpsInits() != 1) { + if (producerLinalgOp.getNumDpsInits() != 1) return emitDefiniteFailure() << "requires producer_op to have exactly one init operand (got " << producerLinalgOp.getNumDpsInits() << ")"; - } auto initOperand = producerLinalgOp.getDpsInits()[0]; - // The containing op may be a user of producerOp: use isAncestor. int64_t numUsesInContainingOp = llvm::count_if(initOperand.getUsers(), [&](Operation *op) { return containingOp->isAncestor(op); @@ -2556,22 +2550,17 @@ DiagnosedSilenceableFailure transform::FuseIntoContainingMemrefOp::apply( return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } - // Default diagnostic, to be complemented with more failure information. - Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark); - diag << "could not fuse " << *producerOp << " into " << *containingOp; - - Operation *tiled = - tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp); + Operation *tiled = xilinx::air::runFuseIntoContainingMemref( + producerOp, containingOp, rewriter); if (tiled) { - LLVM_DEBUG(llvm::dbgs() << "\nFused a direct extract use\n" - << *containingOp); fusedOps.push_back(tiled); rewriter.eraseOp(producerOp); - results.set(llvm::cast(getFusedOp()), fusedOps); return DiagnosedSilenceableFailure::success(); } + Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark); + diag << "could not fuse " << *producerOp << " into " << *containingOp; results.set(llvm::cast(getFusedOp()), ArrayRef()); return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } @@ -2580,144 +2569,6 @@ DiagnosedSilenceableFailure transform::FuseIntoContainingMemrefOp::apply( // HoistLoopInvariantTransfersOp / HoistAllAccumulatorTransfersOp //===----------------------------------------------------------------------===// -// Forward declaration (defined in EliminateRedundantVectorTransfersOp section) -static bool areEquivalentIndices(Value idx1, Value idx2); - -/// Check if a value depends on the given loop induction variable -static bool dependsOnLoopIV(Value val, Value loopIV) { - if (val == loopIV) - return true; - - // Check if the value is defined by an affine.apply that uses the loop IV - if (auto affineOp = val.getDefiningOp()) { - for (Value operand : affineOp.getMapOperands()) { - if (dependsOnLoopIV(operand, loopIV)) - return true; - } - } - - // Check for arithmetic operations - if (auto defOp = val.getDefiningOp()) { - for (Value operand : defOp->getOperands()) { - if (dependsOnLoopIV(operand, loopIV)) - return true; - } - } - - return false; -} - -/// Recursively clone an operation and its operands, using current insertion -/// point. Only clones operations that are inside the loop being hoisted from. -static Value cloneOpAndOperands(Operation *op, Value loopIV, scf::ForOp loopOp, - RewriterBase &rewriter, IRMapping &mapping) { - // If already mapped, return the mapped value - if (!op->getResults().empty()) - if (mapping.contains(op->getResult(0))) - return mapping.lookup(op->getResult(0)); - - // Clone operand-producing operations first - for (Value operand : op->getOperands()) { - if (operand == loopIV) - continue; // Can't clone loop IV - - if (mapping.contains(operand)) - continue; // Already cloned - - // BlockArguments from enclosing loops are still in scope after hoisting - - // use directly - if (isa(operand) && operand != loopIV) - continue; // BlockArguments from outer loops are still accessible - - Operation *defOp = operand.getDefiningOp(); - if (!defOp) - continue; - - // If the defining operation is outside the loop we're hoisting from, - // it's already in scope - use directly without cloning - if (!loopOp->isAncestor(defOp)) - continue; - - if (!dependsOnLoopIV(operand, loopIV)) { - Value clonedOperand = - cloneOpAndOperands(defOp, loopIV, loopOp, rewriter, mapping); - mapping.map(operand, clonedOperand); - } - } - - // Clone this operation at the current insertion point (don't reset it!) - Operation *cloned = rewriter.clone(*op, mapping); - if (cloned->getResults().empty()) - return nullptr; - else - return cloned->getResult(0); -} - -/// Hoist a single transfer read/write pair out of a loop. The read is cloned -/// before the loop, the write is cloned after the loop, and an iter_arg is -/// added to carry the accumulator value through the loop body. -/// Returns the new ForOp on success. -static FailureOr -hoistTransferPairFromLoop(vector::TransferReadOp readOp, - vector::TransferWriteOp writeOp, scf::ForOp loopOp, - RewriterBase &rewriter) { - Value loopIV = loopOp.getInductionVar(); - - // Clone the read and its operands before the loop - rewriter.setInsertionPoint(loopOp); - IRMapping readMapping; - Value clonedReadResult = - cloneOpAndOperands(readOp, loopIV, loopOp, rewriter, readMapping); - - // Capture writeVector before replaceWithAdditionalYields - Value writeVector = writeOp.getVector(); - auto yieldValuesFn = - [&](OpBuilder &b, Location loc, - ArrayRef newBbArgs) -> SmallVector { - BlockArgument readIterArg = newBbArgs.back(); - rewriter.replaceAllUsesWith(readOp.getResult(), readIterArg); - SmallVector yieldValues; - yieldValues.push_back(writeVector); - return yieldValues; - }; - - FailureOr newLoopResult = - cast(loopOp.getOperation()) - .replaceWithAdditionalYields(rewriter, ValueRange{clonedReadResult}, - true, yieldValuesFn); - if (failed(newLoopResult)) - return failure(); - - auto newLoop = cast(newLoopResult->getOperation()); - rewriter.eraseOp(readOp); - - // Clone the write operation after the loop using the yielded value - Value valueToWrite = newLoop.getResults().back(); - IRMapping writeMapping; - writeMapping.map(writeVector, valueToWrite); - - rewriter.setInsertionPointAfter(newLoop); - - for (Value index : writeOp.getIndices()) { - Operation *defOp = index.getDefiningOp(); - if (!defOp || dependsOnLoopIV(index, loopIV)) - continue; - if (!newLoop->isProperAncestor(defOp)) - continue; - if (!writeMapping.contains(index)) { - Value clonedIndex = - cloneOpAndOperands(defOp, loopIV, newLoop, rewriter, writeMapping); - if (clonedIndex) - writeMapping.map(index, clonedIndex); - } - } - - rewriter.clone(*writeOp.getOperation(), writeMapping); - rewriter.eraseOp(writeOp); - - return newLoop; -} - DiagnosedSilenceableFailure transform::HoistLoopInvariantTransfersOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { @@ -2732,86 +2583,16 @@ DiagnosedSilenceableFailure transform::HoistLoopInvariantTransfersOp::apply( << "requires exactly one scope_op and one loop_op handle"; } - auto scopeOp = scopeOps[0]; auto loopOp = dyn_cast_if_present(loopOps[0]); - if (!loopOp) { + if (!loopOp) return emitDefiniteFailure() << "loop_op must be an scf.for"; - } - - if (!scopeOp->isProperAncestor(loopOp)) { - return emitDefiniteFailure() << "loop must be inside the scope operation"; - } - - // Iteratively discover and hoist one loop-invariant transfer pair at a time. - // After each hoist, the loop is replaced with a new loop, so we re-discover - // pairs in the new loop to avoid stale Operation* pointers. - scf::ForOp currentLoop = loopOp; - while (true) { - Value loopIV = currentLoop.getInductionVar(); - - // Find one loop-invariant write and its paired read - vector::TransferWriteOp foundWrite = nullptr; - vector::TransferReadOp foundRead = nullptr; - - currentLoop->walk([&](vector::TransferWriteOp writeOp) { - if (foundWrite) - return; - if (writeOp->getParentOfType() != currentLoop) - return; + auto newLoop = xilinx::air::runHoistLoopInvariantTransfers(scopeOps[0], + loopOp, rewriter); + if (failed(newLoop)) + return emitDefiniteFailure() << "hoist-loop-invariant-transfers failed"; - // Check all write indices are loop-invariant - bool allInvariant = true; - for (Value index : writeOp.getIndices()) { - if (dependsOnLoopIV(index, loopIV)) { - allInvariant = false; - break; - } - } - if (!allInvariant) - return; - - // Find paired read with same memref and matching loop-invariant indices - currentLoop->walk([&](vector::TransferReadOp readOp) { - if (foundRead) - return; - if (readOp->getParentOfType() != currentLoop) - return; - if (readOp.getBase() != writeOp.getBase()) - return; - - for (Value index : readOp.getIndices()) { - if (dependsOnLoopIV(index, loopIV)) - return; - } - - if (readOp.getIndices().size() != writeOp.getIndices().size()) - return; - for (auto [ri, wi] : - llvm::zip(readOp.getIndices(), writeOp.getIndices())) { - if (!areEquivalentIndices(ri, wi)) - return; - } - - foundRead = readOp; - }); - - if (foundRead) - foundWrite = writeOp; - }); - - if (!foundWrite || !foundRead) - break; // No more pairs to hoist - - FailureOr newLoop = - hoistTransferPairFromLoop(foundRead, foundWrite, currentLoop, rewriter); - if (failed(newLoop)) { - return emitDefiniteFailure() << "failed to hoist transfer pair"; - } - currentLoop = *newLoop; - } - - SmallVector resultOps = {currentLoop.getOperation()}; + SmallVector resultOps = {newLoop->getOperation()}; results.set(llvm::cast(getResult()), resultOps); return DiagnosedSilenceableFailure::success(); } @@ -3044,26 +2825,13 @@ DiagnosedSilenceableFailure transform::RemoveUninitializedCopyOp::apply( } SmallVector transformedOps; - for (Operation *target : targets) { auto funcOp = dyn_cast_if_present(target); - if (!funcOp) { + if (!funcOp) return emitDefiniteFailure() << "target must be a func.func operation"; - } - - MLIRContext *ctx = funcOp.getContext(); - RewritePatternSet patterns(ctx); - - // Apply unified copy optimization pattern that: - // 1. Removes copy operations with uninitialized sources - // 2. Replaces copy operations with fill when source is only filled - patterns.insert, - OptimizeCopyOpPattern>(ctx); - (void)applyPatternsGreedily(funcOp, std::move(patterns)); - + (void)xilinx::air::runRemoveUninitializedCopy(funcOp); transformedOps.push_back(funcOp); } - results.set(llvm::cast(getResult()), transformedOps); return DiagnosedSilenceableFailure::success(); } @@ -3085,20 +2853,10 @@ DiagnosedSilenceableFailure transform::EliminateCascadeMemcpyOp::apply( } SmallVector transformedOps; - for (Operation *target : targets) { - MLIRContext *ctx = target->getContext(); - RewritePatternSet patterns(ctx); - - // Use the existing EliminateIntermediateMemrefPattern - patterns.insert(ctx); - - // Apply the pattern to eliminate cascade memcpy operations - (void)applyPatternsGreedily(target, std::move(patterns)); - + (void)xilinx::air::runEliminateCascadeMemcpy(target); transformedOps.push_back(target); } - results.set(llvm::cast(getResult()), transformedOps); return DiagnosedSilenceableFailure::success(); } @@ -3120,20 +2878,10 @@ DiagnosedSilenceableFailure transform::ConvertMemrefCopyToLinalgCopyOp::apply( } SmallVector transformedOps; - for (Operation *target : targets) { - MLIRContext *ctx = target->getContext(); - RewritePatternSet patterns(ctx); - - // Use the ConvertMemrefCopyToLinalgCopyPattern - patterns.insert(ctx); - - // Apply the pattern to convert memref.copy to linalg.copy operations - (void)applyPatternsGreedily(target, std::move(patterns)); - + (void)xilinx::air::runConvertMemrefCopyToLinalgCopy(target); transformedOps.push_back(target); } - results.set(llvm::cast(getResult()), transformedOps); return DiagnosedSilenceableFailure::success(); } @@ -4087,37 +3835,10 @@ transform::FuseTruncfLinalgOp::apply(transform::TransformRewriter &rewriter, "is consumed by truncf_op"; } - // Perform the fusion: create a fused generic, then replace it with a - // linalg.matmul that has the fused output type (bf16). LLVM 23's - // specialize rejects generics with output casts, so we bypass it by - // directly creating the matmul with the fused type. - FailureOr fusedOp = - fuseTruncfIntoProducer(rewriter, producerLinalgOp, truncfLinalgOp); - if (failed(fusedOp)) { + FailureOr fusedOp = xilinx::air::runFuseTruncfLinalg( + producerLinalgOp, truncfLinalgOp, rewriter); + if (failed(fusedOp)) return emitDefiniteFailure() << "failed to fuse the operations"; - } - - // LLVM 23: specialize rejects generics with output casts (truncf→yield). - // If the fused op has 2D+ inputs (matmul-compatible), replace with a - // linalg.matmul directly, bypassing specialize. The matmul body auto- - // generates in the output element type (bf16), and Phase 12 adds - // extf/truncf pairs for f32 accumulation during vectorization. - auto inputType = - dyn_cast(fusedOp->getDpsInputs()[0].getType()); - if (inputType && inputType.getRank() >= 2) { - rewriter.setInsertionPoint(*fusedOp); - auto matmulOp = linalg::MatmulOp::create( - rewriter, fusedOp->getLoc(), fusedOp->getResultTypes(), - ValueRange{fusedOp->getDpsInputs()[0], fusedOp->getDpsInputs()[1]}, - ValueRange{fusedOp->getDpsInits()[0]}); - rewriter.replaceOp(*fusedOp, matmulOp->getResults()); - - SmallVector resultOps = {matmulOp.getOperation()}; - results.set(llvm::cast(getFusedOp()), resultOps); - return DiagnosedSilenceableFailure::success(); - } - - // For non-matmul cases (1D, etc.), return the generic as-is. SmallVector resultOps = {*fusedOp}; results.set(llvm::cast(getFusedOp()), resultOps); return DiagnosedSilenceableFailure::success(); @@ -4135,15 +3856,6 @@ void transform::FuseTruncfLinalgOp::getEffects( // VectorTypeCastOp //===----------------------------------------------------------------------===// -/// Calculate the total number of elements in a vector type -static int64_t getVectorNumElements(VectorType vecType) { - int64_t numElements = 1; - for (int64_t dim : vecType.getShape()) { - numElements *= dim; - } - return numElements; -} - /// Helper function to create cast operations for both scalar and vector types static Value createTypeCast(OpBuilder &builder, Location loc, Value input, Type targetElementType, bool isExtension) { @@ -4227,7 +3939,7 @@ static FailureOr applyVectorTypeCastToOp( for (auto [idx, operand] : llvm::enumerate(op->getOperands())) { if (auto vectorType = dyn_cast_if_present(operand.getType())) { hasAnyVectors = true; - if (getVectorNumElements(vectorType) != 1) { + if (vectorType.getNumElements() != 1) { allVectorsAreSingleElement = false; } } @@ -4236,7 +3948,7 @@ static FailureOr applyVectorTypeCastToOp( for (auto [idx, result] : llvm::enumerate(op->getResults())) { if (auto vectorType = dyn_cast_if_present(result.getType())) { hasAnyVectors = true; - if (getVectorNumElements(vectorType) != 1) { + if (vectorType.getNumElements() != 1) { allVectorsAreSingleElement = false; } } @@ -4387,6 +4099,55 @@ static FailureOr applyVectorTypeCastToOp( return newOp; } +// Free C++ entry point used by both transform.air.vector_type_cast and the +// air-vector-cast-for-emulation pass. +LogicalResult xilinx::air::runVectorTypeCastOnTarget( + Operation *target, Type targetElementType, ArrayRef inputIndices, + ArrayRef outputIndices, RewriterBase &rewriter) { + bool hasVectorTypes = false; + for (Value operand : target->getOperands()) + if (isa(operand.getType())) { + hasVectorTypes = true; + break; + } + if (!hasVectorTypes) { + for (Value result : target->getResults()) + if (isa(result.getType())) { + hasVectorTypes = true; + break; + } + } + if (!hasVectorTypes) + return target->emitError("target operation must have vector operands or " + "results, but operation '") + << target->getName() << "' operates on scalar types"; + + bool needsTransformation = false; + for (Value operand : target->getOperands()) + if (auto vt = dyn_cast_if_present(operand.getType())) + if (vt.getElementType() != targetElementType) { + needsTransformation = true; + break; + } + if (!needsTransformation) { + for (Value result : target->getResults()) + if (auto vt = dyn_cast_if_present(result.getType())) + if (vt.getElementType() != targetElementType) { + needsTransformation = true; + break; + } + } + if (!needsTransformation) + return success(); + + // applyVectorTypeCastToOp may return failure for "skip" cases (e.g. all + // vectors size-1). Treat that as success-with-no-change. + SmallVector in(inputIndices.begin(), inputIndices.end()); + SmallVector out(outputIndices.begin(), outputIndices.end()); + (void)applyVectorTypeCastToOp(target, targetElementType, in, out, rewriter); + return success(); +} + DiagnosedSilenceableFailure transform::VectorTypeCastOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, @@ -4401,84 +4162,19 @@ transform::VectorTypeCastOp::apply(transform::TransformRewriter &rewriter, } Type targetElementType = getTargetElementType(); - - // Extract input and output indices from attributes SmallVector inputIndicesToCast = extractFromIntegerArrayAttr(getInputIndices()); SmallVector outputIndicesToCast = extractFromIntegerArrayAttr(getOutputIndices()); SmallVector transformedOps; - for (Operation *target : targets) { - // Check if this operation has vector operands or results - bool hasVectorTypes = false; - for (Value operand : target->getOperands()) { - if (isa(operand.getType())) { - hasVectorTypes = true; - break; - } - } - if (!hasVectorTypes) { - for (Value result : target->getResults()) { - if (isa(result.getType())) { - hasVectorTypes = true; - break; - } - } - } - - if (!hasVectorTypes) { - return emitDefiniteFailure() - << "target operation must have vector operands or results, but " - "operation '" - << target->getName() - << "' operates on scalar types. Vector type casting " - << "can only be applied to operations that work with vector " - "types."; - } - - // Check if this operation has vector types that need casting - bool needsTransformation = false; - for (Value operand : target->getOperands()) { - if (auto vectorType = - dyn_cast_if_present(operand.getType())) { - if (vectorType.getElementType() != targetElementType) { - needsTransformation = true; - break; - } - } - } - if (!needsTransformation) { - for (Value result : target->getResults()) { - if (auto vectorType = - dyn_cast_if_present(result.getType())) { - if (vectorType.getElementType() != targetElementType) { - needsTransformation = true; - break; - } - } - } - } - - if (needsTransformation) { - // Apply transformation directly to the target operation with selective - // casting - FailureOr castedOpOnVector = - applyVectorTypeCastToOp(target, targetElementType, inputIndicesToCast, - outputIndicesToCast, rewriter); - if (failed(castedOpOnVector)) { - // Operation was skipped (e.g., all vectors are single-element) - // This is not an error, just add the original operation unchanged - transformedOps.push_back(target); - } else { - transformedOps.push_back(*castedOpOnVector); - } - } else { - transformedOps.push_back(target); - } + if (failed(xilinx::air::runVectorTypeCastOnTarget( + target, targetElementType, inputIndicesToCast, outputIndicesToCast, + rewriter))) + return emitDefiniteFailure() << "vector_type_cast failed"; + transformedOps.push_back(target); } - results.set(llvm::cast(getResult()), transformedOps); return DiagnosedSilenceableFailure::success(); } @@ -4487,125 +4183,6 @@ transform::VectorTypeCastOp::apply(transform::TransformRewriter &rewriter, // EliminateRedundantVectorTransfersOp //===----------------------------------------------------------------------===// -/// Check if two values are semantically equivalent indices -static bool areEquivalentIndices(Value idx1, Value idx2) { - // Direct SSA value equality - if (idx1 == idx2) - return true; - - // Check if both are results of affine.apply with the same map and operands - auto affineOp1 = idx1.getDefiningOp(); - auto affineOp2 = idx2.getDefiningOp(); - - if (affineOp1 && affineOp2) { - // Check if they use the same affine map - if (affineOp1.getAffineMap() != affineOp2.getAffineMap()) - return false; - - // Check if they have the same number of operands - if (affineOp1.getMapOperands().size() != affineOp2.getMapOperands().size()) - return false; - - // Check if all operands are identical - for (auto [op1, op2] : - llvm::zip(affineOp1.getMapOperands(), affineOp2.getMapOperands())) { - if (op1 != op2) - return false; - } - - return true; - } - - // Check if both are constants with the same value - auto constOp1 = idx1.getDefiningOp(); - auto constOp2 = idx2.getDefiningOp(); - - if (constOp1 && constOp2) { - return constOp1.value() == constOp2.value(); - } - - return false; -} - -/// Check if two vector.transfer_read operations read from the same location -static bool areIdenticalReads(vector::TransferReadOp read1, - vector::TransferReadOp read2) { - // Check if they read from the same memref - if (read1.getBase() != read2.getBase()) - return false; - - // Check if they have the same number of indices - if (read1.getIndices().size() != read2.getIndices().size()) - return false; - - // Check if all indices are semantically equivalent - for (auto [idx1, idx2] : llvm::zip(read1.getIndices(), read2.getIndices())) { - if (!areEquivalentIndices(idx1, idx2)) - return false; - } - - // Check if they have the same result type - auto vec1Ty = llvm::cast(read1.getVector().getType()); - auto vec2Ty = llvm::cast(read2.getVector().getType()); - if (vec1Ty != vec2Ty) - return false; - - return true; -} - -/// Check if there are any writes to the memref between two operations -static bool hasWritesBetweenReads(vector::TransferReadOp firstRead, - vector::TransferReadOp secondRead) { - Value sourceMemref = firstRead.getBase(); - - // Get the block containing both reads - Block *block = firstRead->getBlock(); - if (block != secondRead->getBlock()) - return true; // Conservative: assume writes if in different blocks - - // Find the operations between the two reads - auto firstIt = firstRead->getIterator(); - auto secondIt = secondRead->getIterator(); - - // Iterate from first read to second read - for (auto it = ++firstIt; it != secondIt; ++it) { - Operation *op = &(*it); - - // Check if this operation writes to the source memref - auto memInterface = dyn_cast_if_present(op); - if (!memInterface) { - // Conservative: if we can't determine effects, assume it might write - if (!op->hasTrait()) - continue; - return true; - } - - SmallVector effects; - memInterface.getEffects(effects); - - for (auto &effect : effects) { - if (!isa(effect.getEffect())) - continue; - - Value effectValue = effect.getValue(); - if (!effectValue) - return true; // Unknown write target, be conservative - - // Check if the write is to the same memref or a view of it - if (effectValue == sourceMemref) - return true; - - // Check if the effect value is derived from the same memref - if (auto subview = effectValue.getDefiningOp()) { - if (subview.getSource() == sourceMemref) - return true; - } - } - } - - return false; -} - DiagnosedSilenceableFailure transform::EliminateRedundantVectorTransfersOp::apply( transform::TransformRewriter &rewriter, @@ -4621,49 +4198,11 @@ transform::EliminateRedundantVectorTransfersOp::apply( SmallVector transformedOps; int eliminatedCount = 0; - for (Operation *target : targets) { - // Collect all vector.transfer_read operations in this target - SmallVector transferReads; - target->walk([&](vector::TransferReadOp readOp) { - transferReads.push_back(readOp); - }); - - // Track which reads have been eliminated - llvm::SmallDenseSet eliminated; - - // Compare each pair of reads - for (size_t i = 0; i < transferReads.size(); ++i) { - if (eliminated.contains(transferReads[i])) - continue; - - for (size_t j = i + 1; j < transferReads.size(); ++j) { - if (eliminated.contains(transferReads[j])) - continue; - - vector::TransferReadOp firstRead = transferReads[i]; - vector::TransferReadOp secondRead = transferReads[j]; - - // Check if the reads are identical - if (!areIdenticalReads(firstRead, secondRead)) - continue; - - // Check if there are writes between them - if (hasWritesBetweenReads(firstRead, secondRead)) - continue; - - // Replace the second read with the result of the first read - rewriter.replaceAllUsesWith(secondRead.getResult(), - firstRead.getResult()); - rewriter.eraseOp(secondRead); - eliminated.insert(secondRead); - eliminatedCount++; - } - } - + eliminatedCount += + xilinx::air::runEliminateRedundantVectorTransfers(target, rewriter); transformedOps.push_back(target); } - if (eliminatedCount > 0) { LLVM_DEBUG(llvm::dbgs() << "Eliminated " << eliminatedCount << " redundant vector.transfer_read operations\n"); @@ -4691,141 +4230,14 @@ transform::FlattenForIterArgsOp::apply(transform::TransformRewriter &rewriter, } SmallVector transformedOps; - for (Operation *target : targets) { auto forOp = dyn_cast_if_present(target); - if (!forOp) { + if (!forOp) return emitDefiniteFailure() << "target must be an scf.for operation"; - } - - Location loc = forOp.getLoc(); - - // Collect vector-typed iter_args - SmallVector vectorIterArgIndices; - SmallVector originalVectorTypes; - SmallVector flattenedVectorTypes; - - for (auto [idx, iterArg] : llvm::enumerate(forOp.getInitArgs())) { - if (auto vecType = dyn_cast_if_present(iterArg.getType())) { - vectorIterArgIndices.push_back(idx); - originalVectorTypes.push_back(vecType); - - // Create flattened vector type - int64_t numElements = getVectorNumElements(vecType); - VectorType flatType = - VectorType::get({numElements}, vecType.getElementType()); - flattenedVectorTypes.push_back(flatType); - } - } - - // If no vector iter_args, nothing to do - if (vectorIterArgIndices.empty()) { - transformedOps.push_back(target); - continue; - } - - // Step 1: Insert vector.shape_cast operations before the loop to flatten - // init values - rewriter.setInsertionPoint(forOp); - SmallVector newInitArgs(forOp.getInitArgs().begin(), - forOp.getInitArgs().end()); - - for (auto [idx, vecIdx] : llvm::enumerate(vectorIterArgIndices)) { - Value initArg = forOp.getInitArgs()[vecIdx]; - auto shapeCast = vector::ShapeCastOp::create( - rewriter, loc, flattenedVectorTypes[idx], initArg); - newInitArgs[vecIdx] = shapeCast.getResult(); - } - - // Step 2: Create new result types (flattened for vector types) - SmallVector newResultTypes; - for (auto [idx, resultType] : llvm::enumerate(forOp.getResultTypes())) { - auto it = llvm::find(vectorIterArgIndices, idx); - if (it != vectorIterArgIndices.end()) { - size_t vecIdx = std::distance(vectorIterArgIndices.begin(), it); - newResultTypes.push_back(flattenedVectorTypes[vecIdx]); - } else { - newResultTypes.push_back(resultType); - } - } - - // Step 3: Create new scf.for with flattened iter_args - auto newForOp = - scf::ForOp::create(rewriter, loc, forOp.getLowerBound(), - forOp.getUpperBound(), forOp.getStep(), newInitArgs); - - // Step 4: Clone the loop body and insert shape_cast operations - Block *oldBody = forOp.getBody(); - Block *newBody = newForOp.getBody(); - - rewriter.setInsertionPointToStart(newBody); - IRMapping mapping; - - // Map the induction variable - mapping.map(oldBody->getArgument(0), newBody->getArgument(0)); - - // For vector iter_args, insert shape_cast to convert back to original shape - for (auto [idx, vecIdx] : llvm::enumerate(vectorIterArgIndices)) { - BlockArgument newArg = newBody->getArgument(vecIdx + 1); - auto shapeCast = vector::ShapeCastOp::create( - rewriter, loc, originalVectorTypes[idx], newArg); - mapping.map(oldBody->getArgument(vecIdx + 1), shapeCast.getResult()); - } - - // Map non-vector iter_args directly - for (auto [idx, arg] : - llvm::enumerate(oldBody->getArguments().drop_front(1))) { - if (llvm::find(vectorIterArgIndices, idx) == vectorIterArgIndices.end()) { - mapping.map(arg, newBody->getArgument(idx + 1)); - } - } - - // Clone operations from old body (except the terminator) - for (Operation &op : oldBody->without_terminator()) { - rewriter.clone(op, mapping); - } - - // Step 5: Handle the yield operation - auto oldYield = cast(oldBody->getTerminator()); - SmallVector newYieldOperands; - - for (auto [idx, yieldValue] : llvm::enumerate(oldYield.getOperands())) { - auto it = llvm::find(vectorIterArgIndices, idx); - if (it != vectorIterArgIndices.end()) { - // Flatten the yielded vector value - size_t vecIdx = std::distance(vectorIterArgIndices.begin(), it); - Value mappedValue = mapping.lookup(yieldValue); - auto shapeCast = vector::ShapeCastOp::create( - rewriter, loc, flattenedVectorTypes[vecIdx], mappedValue); - newYieldOperands.push_back(shapeCast.getResult()); - } else { - newYieldOperands.push_back(mapping.lookup(yieldValue)); - } - } - - scf::YieldOp::create(rewriter, loc, newYieldOperands); - - // Step 6: Insert shape_cast operations after the loop to convert results - // back - rewriter.setInsertionPointAfter(newForOp); - SmallVector finalResults; - - for (auto [idx, result] : llvm::enumerate(newForOp.getResults())) { - auto it = llvm::find(vectorIterArgIndices, idx); - if (it != vectorIterArgIndices.end()) { - size_t vecIdx = std::distance(vectorIterArgIndices.begin(), it); - auto shapeCast = vector::ShapeCastOp::create( - rewriter, loc, originalVectorTypes[vecIdx], result); - finalResults.push_back(shapeCast.getResult()); - } else { - finalResults.push_back(result); - } - } - - // Replace uses of the old loop's results - rewriter.replaceOp(forOp, finalResults); - - transformedOps.push_back(newForOp.getOperation()); + auto newLoop = xilinx::air::runFlattenForIterArgs(forOp, rewriter); + if (failed(newLoop)) + return emitDefiniteFailure() << "flatten-for-iter-args failed"; + transformedOps.push_back(newLoop->getOperation()); } results.set(llvm::cast(getResult()), transformedOps); @@ -4836,32 +4248,6 @@ transform::FlattenForIterArgsOp::apply(transform::TransformRewriter &rewriter, // HoistVectorTransferPointersOp //===----------------------------------------------------------------------===// -namespace { -/// Check if a value depends on the given loop induction variable -bool dependsOnLoopIVForHoist(Value val, Value loopIV) { - if (val == loopIV) - return true; - - // Check if the value is defined by an affine.apply that uses the loop IV - if (auto affineOp = val.getDefiningOp()) { - for (Value operand : affineOp.getMapOperands()) { - if (dependsOnLoopIVForHoist(operand, loopIV)) - return true; - } - } - - // Check for arithmetic operations - if (auto defOp = val.getDefiningOp()) { - for (Value operand : defOp->getOperands()) { - if (dependsOnLoopIVForHoist(operand, loopIV)) - return true; - } - } - - return false; -} -} // namespace - DiagnosedSilenceableFailure transform::HoistVectorTransferPointersOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { @@ -4875,307 +4261,14 @@ DiagnosedSilenceableFailure transform::HoistVectorTransferPointersOp::apply( } SmallVector transformedOps; - for (Operation *target : targets) { auto forOp = dyn_cast_if_present(target); - if (!forOp) { + if (!forOp) return emitDefiniteFailure() << "target must be an scf.for operation"; - } - - Value loopIV = forOp.getInductionVar(); - Location loc = forOp.getLoc(); - OpBuilder::InsertionGuard guard(rewriter); - - // Collect all vector transfer operations with IV-dependent indices - struct TransferOpInfo { - Operation *op; - Value base; - MemRefType memrefType; - VectorType vectorType; - SmallVector indices; - int64_t constantStride; // Total constant stride per iteration - bool hasIVDependentIndices; - }; - - SmallVector transferOps; - - for (Operation &op : forOp.getBody()->without_terminator()) { - auto transferOp = dyn_cast_if_present(&op); - if (!transferOp) - continue; - - Value base = transferOp.getBase(); - auto memrefType = dyn_cast_if_present(base.getType()); - if (!memrefType) - continue; - - VectorType vectorType; - if (auto readOp = dyn_cast_if_present(&op)) { - vectorType = readOp.getVectorType(); - } else if (auto writeOp = - dyn_cast_if_present(&op)) { - vectorType = writeOp.getVectorType(); - } else { - continue; - } - - SmallVector indices(transferOp.getIndices().begin(), - transferOp.getIndices().end()); - - // Check if any indices depend on loop IV and compute constant stride - bool hasIVDependentIndices = false; - int64_t constantStride = 0; - - for (size_t dimIdx = 0; dimIdx < indices.size(); ++dimIdx) { - Value idx = indices[dimIdx]; - if (dependsOnLoopIVForHoist(idx, loopIV)) { - hasIVDependentIndices = true; - - // Calculate the stride for this dimension - int64_t dimStride = 1; - for (size_t j = dimIdx + 1; - j < static_cast(memrefType.getRank()); ++j) { - dimStride *= memrefType.getShape()[j]; - } - - // For now, assume the IV coefficient is 1 (i.e., the index is IV or - // IV + const) This is the total stride increment per loop iteration - constantStride += dimStride; - } - } - - transferOps.push_back({&op, base, memrefType, vectorType, indices, - constantStride, hasIVDependentIndices}); - } - - // Prepare to add iter_args for each transfer operation with IV-dependent - // indices - SmallVector newInitArgs; - SmallVector flatMemrefs; - - for (const auto &info : transferOps) { - if (!info.hasIVDependentIndices) - continue; - - // Flatten the memref if needed - rewriter.setInsertionPoint(forOp); - Value flatMemref = info.base; - if (info.memrefType.getRank() > 1) { - int64_t totalSize = 1; - for (int64_t dim : info.memrefType.getShape()) { - if (dim == ShapedType::kDynamic) - return emitDefiniteFailure() - << "dynamic memref shapes not supported"; - totalSize *= dim; - } - - MemRefType flatMemrefType = - MemRefType::get({totalSize}, info.memrefType.getElementType(), - AffineMap(), info.memrefType.getMemorySpace()); - - SmallVector reassociation; - ReassociationIndices allDims; - for (size_t i = 0; i < static_cast(info.memrefType.getRank()); - ++i) { - allDims.push_back(i); - } - reassociation.push_back(allDims); - - flatMemref = memref::CollapseShapeOp::create( - rewriter, loc, flatMemrefType, info.base, reassociation); - } - flatMemrefs.push_back(flatMemref); - - // Compute base pointer (with zeros for IV-dependent parts) - int64_t rank = info.memrefType.getRank(); - AffineExpr linearExpr = rewriter.getAffineConstantExpr(0); - int64_t stride = 1; - for (int64_t i = rank - 1; i >= 0; --i) { - linearExpr = linearExpr + rewriter.getAffineDimExpr(i) * stride; - if (i > 0) - stride *= info.memrefType.getShape()[i]; - } - auto linearMap = AffineMap::get(rank, 0, linearExpr); - - SmallVector baseIndices; - IRMapping indexMapping; - for (Value idx : info.indices) { - if (!dependsOnLoopIVForHoist(idx, loopIV)) { - if (auto defOp = idx.getDefiningOp()) { - Value clonedIdx = cloneOpAndOperands(defOp, loopIV, forOp, rewriter, - indexMapping); - if (clonedIdx) - baseIndices.push_back(clonedIdx); - else - baseIndices.push_back(idx); - } else { - baseIndices.push_back(idx); - } - } else { - baseIndices.push_back( - arith::ConstantIndexOp::create(rewriter, loc, 0)); - } - } - - Value basePointer = - affine::AffineApplyOp::create(rewriter, loc, linearMap, baseIndices); - - newInitArgs.push_back(basePointer); - } - - // If there are no IV-dependent transfers, just process them normally - if (newInitArgs.empty()) { - // Process all transfers without using iter_args - for (const auto &info : transferOps) { - rewriter.setInsertionPoint(info.op); - - // Flatten vector type - int64_t numElements = getVectorNumElements(info.vectorType); - VectorType flatVectorType = - VectorType::get({numElements}, info.vectorType.getElementType()); - - // Use the base directly - rewriter.setInsertionPoint(forOp); - Value flatMemref = info.base; - if (info.memrefType.getRank() > 1) { - int64_t totalSize = 1; - for (int64_t dim : info.memrefType.getShape()) { - totalSize *= dim; - } - MemRefType flatMemrefType = - MemRefType::get({totalSize}, info.memrefType.getElementType(), - AffineMap(), info.memrefType.getMemorySpace()); - SmallVector reassociation; - ReassociationIndices allDims; - for (size_t i = 0; i < static_cast(info.memrefType.getRank()); - ++i) { - allDims.push_back(i); - } - reassociation.push_back(allDims); - flatMemref = memref::CollapseShapeOp::create( - rewriter, loc, flatMemrefType, info.base, reassociation); - } - - // Compute pointer from indices - int64_t rank = info.memrefType.getRank(); - AffineExpr linearExpr = rewriter.getAffineConstantExpr(0); - int64_t stride = 1; - for (int64_t i = rank - 1; i >= 0; --i) { - linearExpr = linearExpr + rewriter.getAffineDimExpr(i) * stride; - if (i > 0) - stride *= info.memrefType.getShape()[i]; - } - auto linearMap = AffineMap::get(rank, 0, linearExpr); - - rewriter.setInsertionPoint(info.op); - Value currentPointer = affine::AffineApplyOp::create( - rewriter, loc, linearMap, info.indices); - - // Transform the transfer operation - AffineMap identityMap1D = AffineMap::get( - 1, 0, rewriter.getAffineDimExpr(0), rewriter.getContext()); - auto inBoundsAttr = rewriter.getBoolArrayAttr({true}); - - if (auto readOp = - dyn_cast_if_present(info.op)) { - Value flatRead = vector::TransferReadOp::create( - rewriter, loc, flatVectorType, flatMemref, - ValueRange{currentPointer}, AffineMapAttr::get(identityMap1D), - readOp.getPadding(), - /*mask=*/Value(), inBoundsAttr); - Value shapedRead = vector::ShapeCastOp::create( - rewriter, loc, info.vectorType, flatRead); - rewriter.replaceOp(readOp, shapedRead); - } else if (auto writeOp = - dyn_cast_if_present(info.op)) { - Value flatValue = vector::ShapeCastOp::create( - rewriter, loc, flatVectorType, writeOp.getVector()); - rewriter.replaceOpWithNewOp( - writeOp, flatValue, flatMemref, ValueRange{currentPointer}, - AffineMapAttr::get(identityMap1D), /*mask=*/Value(), - inBoundsAttr); - } - } - transformedOps.push_back(forOp); - continue; - } - - // Use replaceWithAdditionalYields to add pointer iter_args - auto yieldValuesFn = - [&](OpBuilder &b, Location yieldLoc, - ArrayRef newBbArgs) -> SmallVector { - SmallVector yieldValues; - - // Process each transfer operation with IV-dependent indices - size_t iterArgIdx = 0; - for (size_t i = 0; i < transferOps.size(); ++i) { - const auto &info = transferOps[i]; - if (!info.hasIVDependentIndices) - continue; - - BlockArgument ptrIterArg = - newBbArgs[newBbArgs.size() - newInitArgs.size() + iterArgIdx]; - Value flatMemref = flatMemrefs[iterArgIdx]; - - // Flatten vector type - int64_t numElements = getVectorNumElements(info.vectorType); - VectorType flatVectorType = - VectorType::get({numElements}, info.vectorType.getElementType()); - - // Transform the transfer operation to use the iter_arg pointer - b.setInsertionPoint(info.op); - - AffineMap identityMap1D = - AffineMap::get(1, 0, b.getAffineDimExpr(0), b.getContext()); - auto inBoundsAttr = b.getBoolArrayAttr({true}); - - if (auto readOp = - dyn_cast_if_present(info.op)) { - Value flatRead = vector::TransferReadOp::create( - b, loc, flatVectorType, flatMemref, ValueRange{ptrIterArg}, - AffineMapAttr::get(identityMap1D), readOp.getPadding(), - /*mask=*/Value(), inBoundsAttr); - Value shapedRead = - vector::ShapeCastOp::create(b, loc, info.vectorType, flatRead); - rewriter.replaceOp(readOp, shapedRead); - } else if (auto writeOp = - dyn_cast_if_present(info.op)) { - Value flatValue = vector::ShapeCastOp::create(b, loc, flatVectorType, - writeOp.getVector()); - rewriter.replaceOpWithNewOp( - writeOp, flatValue, flatMemref, ValueRange{ptrIterArg}, - AffineMapAttr::get(identityMap1D), /*mask=*/Value(), - inBoundsAttr); - } - - // Compute next pointer value: current_ptr + constant_stride - Value strideConst = - arith::ConstantIndexOp::create(b, yieldLoc, info.constantStride); - Value nextPtr = - arith::AddIOp::create(b, yieldLoc, ptrIterArg, strideConst); - yieldValues.push_back(nextPtr); - - iterArgIdx++; - } - - return yieldValues; - }; - - // Create new loop with additional iter_args for pointers - FailureOr newLoopResult = - cast(forOp.getOperation()) - .replaceWithAdditionalYields( - rewriter, newInitArgs, // new init operands (base pointers) - true, // replace uses in loop - yieldValuesFn); - - if (failed(newLoopResult)) { - return emitDefiniteFailure() << "failed to add pointer iter_args to loop"; - } - - transformedOps.push_back(newLoopResult->getOperation()); + if (failed(xilinx::air::runHoistVectorTransferPointers(forOp, rewriter))) + return emitDefiniteFailure() << "hoist-vector-transfer-pointers failed"; + transformedOps.push_back(forOp.getOperation()); } - results.set(llvm::cast(getResult()), transformedOps); return DiagnosedSilenceableFailure::success(); } @@ -5197,290 +4290,20 @@ transform::HoistCastPairOp::apply(transform::TransformRewriter &rewriter, llvm::to_vector(state.getPayloadOps(getLoopOp())); if (extensionOps.size() != 1 || truncationOps.size() != 1 || - loopOps.size() != 1) { + loopOps.size() != 1) return emitDefiniteFailure() << "requires exactly one extension_op, " "truncation_op, and loop_op handle"; - } - Operation *extensionOp = extensionOps[0]; - Operation *truncationOp = truncationOps[0]; auto loopOp = dyn_cast_if_present(loopOps[0]); - - if (!loopOp) { + if (!loopOp) return emitDefiniteFailure() << "loop_op handle must be scf.for"; - } - - // Determine extension/truncation operation types and get input/output values - Value extensionInput, extensionOutput; - Value truncationInput, truncationOutput; - bool isFloatingPoint = false; - - if (auto extsiOp = dyn_cast_if_present(extensionOp)) { - extensionInput = extsiOp.getIn(); - extensionOutput = extsiOp.getOut(); - auto trunciOp = dyn_cast_if_present(truncationOp); - if (!trunciOp) { - return emitDefiniteFailure() - << "arith.extsi must be paired with arith.trunci"; - } - truncationInput = trunciOp.getIn(); - truncationOutput = trunciOp.getOut(); - } else if (auto extuiOp = dyn_cast_if_present(extensionOp)) { - extensionInput = extuiOp.getIn(); - extensionOutput = extuiOp.getOut(); - auto trunciOp = dyn_cast_if_present(truncationOp); - if (!trunciOp) { - return emitDefiniteFailure() - << "arith.extui must be paired with arith.trunci"; - } - truncationInput = trunciOp.getIn(); - truncationOutput = trunciOp.getOut(); - } else if (auto extfOp = dyn_cast_if_present(extensionOp)) { - extensionInput = extfOp.getIn(); - extensionOutput = extfOp.getOut(); - auto truncfOp = dyn_cast_if_present(truncationOp); - if (!truncfOp) { - return emitDefiniteFailure() - << "arith.extf must be paired with arith.truncf"; - } - truncationInput = truncfOp.getIn(); - truncationOutput = truncfOp.getOut(); - isFloatingPoint = true; - } else { - return emitDefiniteFailure() << "extension operation must be arith.extsi, " - "arith.extui, or arith.extf"; - } - - // Verify extension and truncation are in the loop - if (!loopOp->isProperAncestor(extensionOp) || - !loopOp->isProperAncestor(truncationOp)) { - return emitDefiniteFailure() - << "extension and truncation operations must be inside the loop"; - } - // Find which iter_arg the extension operates on - BlockArgument iterArg = nullptr; - int64_t iterArgIndex = -1; - vector::ShapeCastOp shapeCastBeforeExtension = nullptr; + auto newLoop = xilinx::air::runHoistCastPair( + extensionOps[0], truncationOps[0], loopOp, rewriter); + if (failed(newLoop)) + return emitDefiniteFailure() << "hoist-cast-pair failed"; - // The extension input might be the iter_arg directly, or derived from it - // through shape_cast - if (auto blockArg = dyn_cast_if_present(extensionInput)) { - if (blockArg.getOwner() == loopOp.getBody() && - blockArg.getArgNumber() > 0) { - iterArg = blockArg; - iterArgIndex = blockArg.getArgNumber() - 1; - } - } else if (auto shapeCastOp = - extensionInput.getDefiningOp()) { - Value shapeCastSource = shapeCastOp.getSource(); - if (auto blockArg = dyn_cast_if_present(shapeCastSource)) { - if (blockArg.getOwner() == loopOp.getBody() && - blockArg.getArgNumber() > 0) { - iterArg = blockArg; - iterArgIndex = blockArg.getArgNumber() - 1; - shapeCastBeforeExtension = shapeCastOp; - } - } - } - - if (!iterArg) { - return emitDefiniteFailure() << "extension must operate on a loop iter_arg " - "(directly or through shape_cast)"; - } - - // Find the value that gets yielded (should come from truncation, possibly - // through shape_cast) - vector::ShapeCastOp shapeCastAfterTruncation = nullptr; - - auto yieldOp = cast(loopOp.getBody()->getTerminator()); - bool truncationIsYielded = false; - int64_t yieldIndex = -1; - - for (auto [idx, yieldValue] : llvm::enumerate(yieldOp.getOperands())) { - if (yieldValue == truncationOutput) { - truncationIsYielded = true; - yieldIndex = idx; - break; - } else if (auto shapeCast = - yieldValue.getDefiningOp()) { - if (shapeCast.getSource() == truncationOutput) { - truncationIsYielded = true; - yieldIndex = idx; - shapeCastAfterTruncation = shapeCast; - break; - } - } - } - - if (!truncationIsYielded || yieldIndex != iterArgIndex) { - return emitDefiniteFailure() << "truncation result must be yielded at the " - "same position as the extension iter_arg"; - } - - Location loc = loopOp.getLoc(); - - // Step 1: Hoist extension before the loop (don't hoist shape_cast yet) - rewriter.setInsertionPoint(loopOp); - Value initValue = loopOp.getInitArgs()[iterArgIndex]; - - // Get the wide element type from the extension output - Type wideElemType = - cast(extensionOutput.getType()).getElementType(); - Type wideInitType = VectorType::get( - cast(initValue.getType()).getShape(), wideElemType); - - // Extend the init value directly (in narrow flat form) - Value extendedInit; - if (isFloatingPoint) { - extendedInit = - arith::ExtFOp::create(rewriter, loc, wideInitType, initValue); - } else if (isa(extensionOp)) { - extendedInit = - arith::ExtSIOp::create(rewriter, loc, wideInitType, initValue); - } else { - extendedInit = - arith::ExtUIOp::create(rewriter, loc, wideInitType, initValue); - } - - // Step 2: Create new loop with wide type for this iter_arg - SmallVector newInitArgs(loopOp.getInitArgs().begin(), - loopOp.getInitArgs().end()); - newInitArgs[iterArgIndex] = extendedInit; - - auto newLoopOp = - scf::ForOp::create(rewriter, loc, loopOp.getLowerBound(), - loopOp.getUpperBound(), loopOp.getStep(), newInitArgs); - - // Step 3: Clone the loop body with proper type adjustments - Block *oldBody = loopOp.getBody(); - Block *newBody = newLoopOp.getBody(); - - rewriter.setInsertionPointToStart(newBody); - IRMapping mapping; - - // Map the induction variable - mapping.map(oldBody->getArgument(0), newBody->getArgument(0)); - - // Map iter_args - for (auto [idx, oldArg] : - llvm::enumerate(oldBody->getArguments().drop_front(1))) { - mapping.map(oldArg, newBody->getArgument(idx + 1)); - } - - // Clone operations from old body, adjusting types as needed - for (Operation &op : oldBody->without_terminator()) { - // Skip extension - its result will be mapped to the wide iter_arg or wide - // shape_cast - if (&op == extensionOp) { - if (shapeCastBeforeExtension) { - // Map extension result to the shape_cast result (which we'll create - // below) Don't map yet - we'll map it when we encounter the shape_cast - } else { - // No shape_cast: map extension result directly to the wide iter_arg - mapping.map(extensionOutput, newBody->getArgument(iterArgIndex + 1)); - } - continue; - } - - // Skip truncation - we'll handle the yielded value specially - if (&op == truncationOp) { - continue; - } - - // Handle shape_cast before extension - clone it with wide element type - if (shapeCastBeforeExtension && - &op == shapeCastBeforeExtension.getOperation()) { - auto narrowVecType = - cast(shapeCastBeforeExtension.getResult().getType()); - auto wideVecType = - VectorType::get(narrowVecType.getShape(), wideElemType); - - Value mappedSource = mapping.lookup(shapeCastBeforeExtension.getSource()); - auto newShapeCast = - vector::ShapeCastOp::create(rewriter, loc, wideVecType, mappedSource); - mapping.map(shapeCastBeforeExtension.getResult(), - newShapeCast.getResult()); - mapping.map(extensionOutput, newShapeCast.getResult()); - continue; - } - - // Handle shape_cast after truncation - clone it with wide element type for - // the yield - if (shapeCastAfterTruncation && - &op == shapeCastAfterTruncation.getOperation()) { - // We'll handle this in the yield processing - continue; - } - - // Clone all other operations normally - rewriter.clone(op, mapping); - } - - // Step 4: Update the yield to yield the wide value - auto oldYield = cast(oldBody->getTerminator()); - SmallVector newYieldOperands; - - for (auto [idx, yieldValue] : llvm::enumerate(oldYield.getOperands())) { - if ((int64_t)idx == iterArgIndex) { - // Get the wide value (truncation input) - Value wideValue = mapping.lookup(truncationInput); - - // If there was a shape_cast after truncation, we need to create a wide - // version of it - if (shapeCastAfterTruncation) { - auto narrowVecType = - cast(shapeCastAfterTruncation.getResult().getType()); - auto wideVecType = - VectorType::get(narrowVecType.getShape(), wideElemType); - - auto newShapeCast = - vector::ShapeCastOp::create(rewriter, loc, wideVecType, wideValue); - newYieldOperands.push_back(newShapeCast.getResult()); - } else { - newYieldOperands.push_back(wideValue); - } - } else { - newYieldOperands.push_back(mapping.lookup(yieldValue)); - } - } - - scf::YieldOp::create(rewriter, loc, newYieldOperands); - - // Step 5: Hoist truncation after the loop - rewriter.setInsertionPointAfter(newLoopOp); - Value wideResult = newLoopOp.getResults()[iterArgIndex]; - - // Get the narrow element type from the original init value - auto narrowElemType = - cast(loopOp.getInitArgs()[iterArgIndex].getType()) - .getElementType(); - auto narrowResultType = VectorType::get( - cast(wideResult.getType()).getShape(), narrowElemType); - - // Create the appropriate truncation operation based on type - Value narrowResult; - if (isFloatingPoint) { - narrowResult = - arith::TruncFOp::create(rewriter, loc, narrowResultType, wideResult); - } else { - narrowResult = - arith::TruncIOp::create(rewriter, loc, narrowResultType, wideResult); - } - - // Step 6: Replace uses of the old loop - SmallVector finalResults; - for (auto [idx, result] : llvm::enumerate(newLoopOp.getResults())) { - if ((int64_t)idx == iterArgIndex) { - finalResults.push_back(narrowResult); - } else { - finalResults.push_back(result); - } - } - - rewriter.replaceOp(loopOp, finalResults); - - SmallVector resultOps = {newLoopOp.getOperation()}; + SmallVector resultOps = {newLoop->getOperation()}; results.set(llvm::cast(getResult()), resultOps); return DiagnosedSilenceableFailure::success(); } @@ -5511,38 +4334,8 @@ transform::FoldUnitExtentDimsOp::apply(transform::TransformRewriter &rewriter, auto funcOp = dyn_cast_if_present(target); if (!funcOp) return emitDefiniteFailure() << "target must be a func.func operation"; - - MLIRContext *ctx = funcOp.getContext(); - - // LLVM 23's collapseValue rejects memrefs with non-identity layouts - // (strided memrefs from subview ops). Override collapseFn to use - // rank-reducing subviews for strided memrefs, allowing the fold to - // handle linalg ops with subview outputs inside air.herd regions. - RewritePatternSet foldPatterns(ctx); - linalg::ControlDropUnitDims options; - options.collapseFn = - [](RewriterBase &rewriter, Location loc, Value operand, - ArrayRef targetShape, - ArrayRef reassociation, - const linalg::ControlDropUnitDims &control) -> FailureOr { - if (auto memrefType = dyn_cast(operand.getType())) { - if (!memrefType.getLayout().isIdentity()) { - return memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand, - targetShape); - } - MemRefLayoutAttrInterface layout; - auto targetType = - MemRefType::get(targetShape, memrefType.getElementType(), layout, - memrefType.getMemorySpace()); - return memref::CollapseShapeOp::create(rewriter, loc, targetType, - operand, reassociation) - .getResult(); - } - return failure(); - }; - linalg::populateFoldUnitExtentDimsPatterns(foldPatterns, options); - (void)applyPatternsGreedily(funcOp, std::move(foldPatterns)); - + if (failed(xilinx::air::runFoldUnitExtentDimsOnFunc(funcOp))) + return emitDefiniteFailure() << "fold-unit-extent-dims failed"; transformedOps.push_back(funcOp); } @@ -5817,25 +4610,13 @@ transform::NormalizeForBoundsOp::apply(transform::TransformRewriter &rewriter, } SmallVector transformedOps; - for (Operation *target : targets) { auto forOp = dyn_cast_if_present(target); - if (!forOp) { + if (!forOp) return emitDefiniteFailure() << "target must be an scf.for operation"; - } - - // Use the utility function from AIRDependencyScheduleOpt to fold - // affine.apply into loop bounds - auto newForOp = xilinx::air::foldAffineApplyIntoLoopBounds(forOp, rewriter); - if (succeeded(newForOp)) { - // Use the returned ForOp (which may be a new operation) - transformedOps.push_back(*newForOp); - } else { - // No transformation was applied, return the original op - transformedOps.push_back(forOp); - } + transformedOps.push_back( + xilinx::air::runNormalizeForBounds(forOp, rewriter).getOperation()); } - results.set(llvm::cast(getResult()), transformedOps); return DiagnosedSilenceableFailure::success(); #else @@ -5857,5 +4638,99 @@ std::unique_ptr createAIRPipelineReducePass() { return std::make_unique(); } +//===----------------------------------------------------------------------===// +// Bufferization & fusion helpers shared between the transform.air.* op +// apply()s in this TU and the air-matmul-codegen orchestrator phases. +// Defined here because the patterns/static helpers they wrap have internal +// linkage in this TU. Declared in AIRMatmulCodegenHelpers.h. +//===----------------------------------------------------------------------===// + +LogicalResult runRemoveUninitializedCopy(func::FuncOp funcOp) { + MLIRContext *ctx = funcOp.getContext(); + RewritePatternSet patterns(ctx); + patterns.insert<::OptimizeCopyOpPattern, + ::OptimizeCopyOpPattern>(ctx); + return success(succeeded(applyPatternsGreedily(funcOp, std::move(patterns)))); +} + +LogicalResult runEliminateCascadeMemcpy(Operation *target) { + MLIRContext *ctx = target->getContext(); + RewritePatternSet patterns(ctx); + patterns.insert(ctx); + return success(succeeded(applyPatternsGreedily(target, std::move(patterns)))); +} + +LogicalResult runConvertMemrefCopyToLinalgCopy(Operation *target) { + MLIRContext *ctx = target->getContext(); + RewritePatternSet patterns(ctx); + patterns.insert(ctx); + return success(succeeded(applyPatternsGreedily(target, std::move(patterns)))); +} + +Operation *runFuseIntoContainingMemref(Operation *producerOp, + Operation *containingOp, + RewriterBase &rewriter) { + Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark); + return ::tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp); +} + +bool containsOnlyTruncfOp(linalg::LinalgOp linalgOp) { + return ::containsOnlyTruncfOp(linalgOp); +} + +bool producesResultForOp(linalg::LinalgOp producerOp, + linalg::LinalgOp truncfOp) { + return ::producesResultForOp(producerOp, truncfOp); +} + +FailureOr runFuseTruncfLinalg(linalg::LinalgOp producerOp, + linalg::LinalgOp truncfOp, + RewriterBase &rewriter) { + if (!::containsOnlyTruncfOp(truncfOp)) + return failure(); + if (!::producesResultForOp(producerOp, truncfOp)) + return failure(); + FailureOr fusedOp = + ::fuseTruncfIntoProducer(rewriter, producerOp, truncfOp); + if (failed(fusedOp)) + return failure(); + + // Discardable attrs on the producer (e.g. `air.matmul_codegen_config` + // attached by an external producer) must survive the rewrite — copy them + // onto the fused/replacement op so downstream consumer passes can find them. + auto propagateDiscardable = [&](Operation *src, Operation *dst) { + for (NamedAttribute a : src->getDiscardableAttrs()) + if (!dst->hasAttr(a.getName())) + dst->setAttr(a.getName(), a.getValue()); + }; + propagateDiscardable(producerOp.getOperation(), fusedOp->getOperation()); + + // For matmul-shaped fusions (2D+ inputs), replace with linalg.matmul of the + // truncated output type so that downstream specialize/pack works. For other + // shapes, return the fused generic. + auto inputType = + dyn_cast(fusedOp->getDpsInputs()[0].getType()); + if (inputType && inputType.getRank() >= 2) { + rewriter.setInsertionPoint(*fusedOp); + auto matmulOp = linalg::MatmulOp::create( + rewriter, fusedOp->getLoc(), fusedOp->getResultTypes(), + ValueRange{fusedOp->getDpsInputs()[0], fusedOp->getDpsInputs()[1]}, + ValueRange{fusedOp->getDpsInits()[0]}); + propagateDiscardable(fusedOp->getOperation(), matmulOp.getOperation()); + rewriter.replaceOp(*fusedOp, matmulOp->getResults()); + return matmulOp.getOperation(); + } + return fusedOp->getOperation(); +} + +scf::ForOp runNormalizeForBounds(scf::ForOp forOp, RewriterBase &rewriter) { +#if AIR_ENABLE_AIE + auto newForOp = xilinx::air::foldAffineApplyIntoLoopBounds(forOp, rewriter); + if (succeeded(newForOp)) + return *newForOp; +#endif + return forOp; +} + } // namespace air } // namespace xilinx diff --git a/mlir/lib/Transform/AIRMatmulBufferizationPasses.cpp b/mlir/lib/Transform/AIRMatmulBufferizationPasses.cpp new file mode 100644 index 000000000..5c1544b9c --- /dev/null +++ b/mlir/lib/Transform/AIRMatmulBufferizationPasses.cpp @@ -0,0 +1,284 @@ +//===- AIRMatmulBufferizationPasses.cpp -------------------------*- C++ -*-===// +// +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT +// +//===----------------------------------------------------------------------===// +// +// Bufferization phases of the air-matmul-codegen orchestrator: bufferize- +// output-l2, bufferize-l1-inputs, bufferize-l1-output, post-bufferize +// cleanup, ping-pong sibling fusion, and bf16-output truncf fusion. +// +//===----------------------------------------------------------------------===// + +#include "air/Transform/AIRMatmulBufferizationPasses.h" + +#include "air/Dialect/AIR/AIRDialect.h" +#include "air/Transform/AIRLinalgBufferize.h" +#include "air/Transform/AIRMatmulCodegenHelpers.h" +#include "air/Transform/AIRMatmulTileL3ToL2Copies.h" +#include "air/Util/Util.h" + +#include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/PatternMatch.h" + +#define DEBUG_TYPE "air-matmul-bufferization-passes" + +using namespace mlir; +using namespace xilinx::air; + +namespace xilinx { +namespace air { + +namespace { + +// `findMarkedOp` / `findMarkedForLoop` live in air/Util/Util.h as +// `xilinx::air::findOpWithAttr` and `findOpOfTypeWithAttr`. + +/// Bufferize `target` into a new allocation in `memorySpace`. +/// `bufferizeDestinationOnly=true` so the targeted op itself is not rewritten; +/// only its destination operand is materialized as a fresh memref alloc. +static LogicalResult bufferizeOpToAllocation( + Operation *target, int64_t memorySpace, + linalg::BufferizeToAllocationOptions ::MemcpyOp memcpyOp, + RewriterBase &rewriter) { + linalg::BufferizeToAllocationOptions options; + options.bufferizeDestinationOnly = true; + options.emitDealloc = true; + options.memcpyOp = memcpyOp; + Attribute memSpaceAttr = + IntegerAttr::get(IntegerType::get(target->getContext(), 64), memorySpace); + Value buffer = + linalg::bufferizeToAllocation(rewriter, options, target, memSpaceAttr); + return success(buffer != nullptr); +} + +} // namespace + +//===----------------------------------------------------------------------===// +// runBufferizeOutputL2Impl (Phase 2) +//===----------------------------------------------------------------------===// + +LogicalResult runBufferizeOutputL2Impl(func::FuncOp f, int64_t memorySpace, + bool fuseOutputTruncfFirst, + bool doTileL3ToL2Copies, int64_t kL2Tile, + StringRef copyALoopMarker, + StringRef copyBLoopMarker, + RewriterBase &rewriter) { + // Optional pre-step 1: convert memref.copy L3->L2 stagings to linalg.copy + // and tile by k-l2-tile (with copy_a_loop / copy_b_loop annotations). + if (doTileL3ToL2Copies) + if (failed(runTileL3ToL2CopiesImpl(f, kL2Tile, copyALoopMarker, + copyBLoopMarker))) + return failure(); + + // Optional pre-step 2: fuse a single-truncf linalg.generic consumer of + // the matmul into the matmul itself before bufferizing the fill, so the + // fill's element type matches the post-fuse matmul. + if (fuseOutputTruncfFirst) + runFuseOutputTruncfImpl(f, rewriter); + + SmallVector fills; + f.walk([&](linalg::FillOp op) { fills.push_back(op); }); + if (fills.empty()) + return success(); // no-op if no fill. + for (linalg::FillOp fill : fills) { + if (!fill.getOperation()->getBlock()) + continue; // erased by a prior iteration's bufferization + if (failed(bufferizeOpToAllocation( + fill, memorySpace, + linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy, + rewriter))) + return failure(); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// AIRMatmulBufferizeL1Output (Phase 3 tail) +//===----------------------------------------------------------------------===// + +// Free-function body for the former `air-matmul-bufferize-l1-output` pass. +// Now invoked from `air-matmul-pack-and-transpose` when its +// `do-bufferize-l1-output` option is set. +LogicalResult runBufferizeL1OutputImpl(func::FuncOp f, int64_t memorySpace, + StringRef packedMatmulMarker, + RewriterBase &rewriter) { + Operation *packedMatmul = xilinx::air::findOpWithAttr(f, packedMatmulMarker); + if (!packedMatmul) + return success(); + auto linalgOp = dyn_cast(packedMatmul); + if (!linalgOp || linalgOp.getNumDpsInits() != 1) + return packedMatmul->emitError( + "packed_matmul op must be a LinalgOp with one DPS init"); + Operation *packC = linalgOp.getDpsInits()[0].getDefiningOp(); + if (!isa_and_nonnull(packC)) + return success(); // pack already bufferized or absent. + if (failed(bufferizeOpToAllocation( + packC, memorySpace, + linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy, + rewriter))) + return failure(); + return success(); +} + +//===----------------------------------------------------------------------===// +// runBufferizeL1InputsImpl (Phase 6a) +//===----------------------------------------------------------------------===// + +LogicalResult runBufferizeL1InputsImpl(func::FuncOp f, int64_t memorySpace, + StringRef memcpyOp, StringRef lhsMarker, + StringRef rhsMarker, + RewriterBase &rewriter) { + auto memcpy = + linalg::BufferizeToAllocationOptions::MemcpyOp::MaterializeInDestination; + if (memcpyOp == "linalg-copy") + memcpy = linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy; + for (StringRef marker : {lhsMarker, rhsMarker}) { + Operation *target = xilinx::air::findOpWithAttr(f, marker); + if (!target) + continue; + if (failed(bufferizeOpToAllocation(target, memorySpace, memcpy, rewriter))) + return failure(); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// AIRMatmulPostBufferizeCleanup (Phase 7+8: remove uninitialized copies, +// eliminate cascade memcpys, then sibling-fuse the K-reduction loop with the +// L3->L2 copy loops for ping-pong buffering. Combined into one pass since +// the two halves are always run back-to-back.) +//===----------------------------------------------------------------------===// + +namespace { + +/// Hoist any same-block ops between `target` and `source` that are used +/// inside *either* loop's body. Required because +/// `fuseIndependentSiblingForLoops` may place the merged loop at the +/// earlier of the two source positions, leaving any in-between ops +/// (including allocs/casts the merged loop depends on) below the new +/// merged-loop position. +static void hoistInterveningDeps(scf::ForOp target, scf::ForOp source) { + Operation *first = target->isBeforeInBlock(source) ? target.getOperation() + : source.getOperation(); + Operation *second = (first == target.getOperation()) ? source.getOperation() + : target.getOperation(); + Block *block = target->getBlock(); + if (block != source->getBlock()) + return; + + llvm::SetVector toHoist; + auto collect = [&](Operation *loopRoot) { + loopRoot->walk([&](Operation *op) { + for (Value v : op->getOperands()) { + Operation *defOp = v.getDefiningOp(); + if (!defOp || defOp->getBlock() != block) + continue; + if (defOp == source.getOperation() || defOp == target.getOperation()) + continue; + if (defOp->isBeforeInBlock(first) || defOp == first) + continue; + if (second->isBeforeInBlock(defOp) || defOp == second) + continue; + toHoist.insert(defOp); + } + }); + }; + collect(target.getOperation()); + collect(source.getOperation()); + + // Sort the to-hoist set topologically and move each above `first` in + // dependency order. Operands defined outside `toHoist` are treated as + // already-ready by computeTopologicalSorting (incomplete-chain semantics). + SmallVector sorted(toHoist.begin(), toHoist.end()); + (void)mlir::computeTopologicalSorting(sorted); + for (Operation *op : sorted) + op->moveBefore(first); +} + +} // namespace + +// Free-function bodies for the prior `fuse-pingpong-loops`, +// `fuse-output-truncf`, and `hoist-static-alloc` passes. Exposed via +// AIRMatmulBufferizationPasses.h so they can be called either from the +// combined post-bufferize-cleanup pass or as option-driven steps inside +// the parametric passes (pack-and-transpose, prologue-epilogue). + +LogicalResult runFusePingpongLoopsImpl(func::FuncOp f, RewriterBase &rewriter) { + scf::ForOp copyA = + xilinx::air::findOpOfTypeWithAttr(f, "copy_a_loop"); + scf::ForOp copyB = + xilinx::air::findOpOfTypeWithAttr(f, "copy_b_loop"); + scf::ForOp kRed = + xilinx::air::findOpOfTypeWithAttr(f, "k_reduction_loop"); + if (!copyA || !copyB || !kRed) + return success(); // not in the right shape; no-op. + + scf::ForOp normalized = runNormalizeForBounds(kRed, rewriter); + hoistInterveningDeps(normalized, copyB); + if (copyB->isBeforeInBlock(normalized)) + copyB->moveBefore(normalized); + scf::ForOp afterB = + fuseIndependentSiblingForLoops(normalized, copyB, rewriter); + if (!afterB) + return failure(); + hoistInterveningDeps(afterB, copyA); + if (copyA->isBeforeInBlock(afterB)) + copyA->moveBefore(afterB); + scf::ForOp afterA = fuseIndependentSiblingForLoops(afterB, copyA, rewriter); + if (!afterA) + return failure(); + return success(); +} + +void runFuseOutputTruncfImpl(func::FuncOp f, RewriterBase &rewriter) { + // Collect all (producer, truncf_only_consumer) pairs first; fusing in- + // place mutates the IR and would invalidate a live walk. + SmallVector> pairs; + f.walk([&](linalg::LinalgOp op) { + if (!containsOnlyTruncfOp(op)) + return; + if (op.getNumDpsInputs() != 1) + return; + auto producerOp = op.getDpsInputs()[0].getDefiningOp(); + if (!producerOp) + return; + if (!producesResultForOp(producerOp, op)) + return; + pairs.emplace_back(producerOp, op); + }); + for (auto &p : pairs) { + if (!p.first->getBlock() || !p.second->getBlock()) + continue; + (void)runFuseTruncfLinalg(p.first, p.second, rewriter); + } +} + +void runHoistStaticAllocImpl(func::FuncOp f, RewriterBase &rewriter) { + hoistStaticAllocsInFunc(rewriter, + cast(f.getOperation())); +} + +// Composite of post-bufferize-cleanup: remove uninitialized copies + +// eliminate cascade memcpys + sibling-fuse pingpong loops. Now invoked +// from `air-matmul-tile-for-vectorize` when its +// `do-post-bufferize-cleanup-first` option is set. +LogicalResult runPostBufferizeCleanupImpl(func::FuncOp f, + RewriterBase &rewriter) { + if (failed(runRemoveUninitializedCopy(f))) + return failure(); + if (failed(runEliminateCascadeMemcpy(f))) + return failure(); + return runFusePingpongLoopsImpl(f, rewriter); +} + +} // namespace air +} // namespace xilinx diff --git a/mlir/lib/Transform/AIRMatmulCodegen.cpp b/mlir/lib/Transform/AIRMatmulCodegen.cpp new file mode 100644 index 000000000..dc776eb56 --- /dev/null +++ b/mlir/lib/Transform/AIRMatmulCodegen.cpp @@ -0,0 +1,322 @@ +//===- AIRMatmulCodegen.cpp -------------------------------------*- C++ -*-===// +// +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT +// +//===----------------------------------------------------------------------===// +// +// AIRMatmulCodegen: single public matmul codegen pass. Internal phases are +// gated by their config (skip-if-empty) and chained with canonicalize/cse + +// upstream one-shot-bufferize. +// +//===----------------------------------------------------------------------===// + +#include "air/Transform/AIRMatmulCodegen.h" +#include "air/Transform/AIRMatmulBufferizationPasses.h" +#include "air/Transform/AIRMatmulCodegenHelpers.h" +#include "air/Transform/AIRMatmulPackAndTranspose.h" +#include "air/Transform/AIRMatmulTilePasses.h" +#include "air/Transform/AIRMatmulVectorizePasses.h" + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" + +#define DEBUG_TYPE "air-matmul-codegen" + +using namespace mlir; +using namespace xilinx::air; + +namespace xilinx { +namespace air { + +namespace { + +// Internal marker constants. The orchestrator owns the marker namespace — +// each phase tags ops with names known to the next consumer phase. Not +// configurable: callers don't need to compose phases out-of-order. +static constexpr llvm::StringLiteral kPackedMatmul = "packed_matmul"; +static constexpr llvm::StringLiteral kLaunchTileForall = "launch_tile_forall"; +static constexpr llvm::StringLiteral kCopyALoop = "copy_a_loop"; +static constexpr llvm::StringLiteral kCopyBLoop = "copy_b_loop"; +static constexpr llvm::StringLiteral kKReductionLoop = "k_reduction_loop"; +static constexpr llvm::StringLiteral kKReductionLoopInner = + "k_reduction_loop_inner"; +static constexpr llvm::StringLiteral kLhsPackInK = "lhs_pack_in_k"; +static constexpr llvm::StringLiteral kRhsPackInK = "rhs_pack_in_k"; +static constexpr llvm::StringLiteral kLhsL2PackInK = "lhs_l2_pack_in_k"; +static constexpr llvm::StringLiteral kRhsL2PackInK = "rhs_l2_pack_in_k"; +static constexpr llvm::StringLiteral kComputeForall = "compute_forall"; +static constexpr llvm::StringLiteral kMatmulCompute = "matmul_compute"; +static constexpr llvm::StringLiteral kFusedLhsL1Pack = "fused_lhs_l1_pack"; +static constexpr llvm::StringLiteral kFusedRhsL1Pack = "fused_rhs_l1_pack"; +static constexpr llvm::StringLiteral kInitFill = "init_fill"; +static constexpr llvm::StringLiteral kPrologueForall = "prologue_forall"; +static constexpr llvm::StringLiteral kEpilogueForall = "epilogue_forall"; + +class AIRMatmulCodegen : public impl::AIRMatmulCodegenBase { +public: + AIRMatmulCodegen() = default; + AIRMatmulCodegen(const AIRMatmulCodegenOptions &opts) + : AIRMatmulCodegenBase(opts) {} + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + // Run a small pipeline at func or module scope. AIRMatmulCodegen runs at + // ModuleOp so dynamic scheduling at either scope is permitted. + bool runFuncScoped(func::FuncOp f, + llvm::function_ref populate) { + OpPassManager pm(func::FuncOp::getOperationName()); + populate(pm); + return succeeded(runPipeline(pm, f)); + } + + bool runModuleScoped(ModuleOp m, + llvm::function_ref populate) { + OpPassManager pm(ModuleOp::getOperationName()); + populate(pm); + return succeeded(runPipeline(pm, m)); + } + + void runOnOperation() override { + ModuleOp module = getOperation(); + SmallVector funcs(module.getOps()); + // Phase L (one-shot bufferize) is module-scoped, so running runOnFunc on + // multiple funcs in the same module would have the first call bufferize + // the whole module and leave subsequent funcs' tensor-IR phases (A--K) + // operating on already-memref IR. All current callers compile a single + // top-level matmul kernel per module; reject anything else explicitly so + // we get a clear error instead of silent misbehavior. + if (clOneShotBufferize && funcs.size() > 1) { + module->emitError("air-matmul-codegen with one-shot-bufferize=true does " + "not support modules with more than one func.func; " + "found ") + << funcs.size() << " functions"; + return signalPassFailure(); + } + for (func::FuncOp f : funcs) + if (failed(runOnFunc(f))) + return; + } + + LogicalResult runOnFunc(func::FuncOp f) { + IRRewriter rewriter(&getContext()); + ModuleOp module = f->getParentOfType(); + auto fail = [&]() { + signalPassFailure(); + return failure(); + }; + + auto canonicalizeCse = [&]() { + return runFuncScoped(f, [](OpPassManager &pm) { + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + }); + }; + + // ---------- Phase 0: pre-fold unit-extent dims ---------- + if (failed(runFoldUnitExtentDimsOnFunc(f))) + return fail(); + + // Phase C placement: single-pack flows (no L1 pack) run bufferize-output-l2 + // BEFORE Phase A and Phase B — required by the tile-l3-to-l2-copies and + // fuse-output-truncf-first pre-steps (which must operate on un-packed IR) + // and so that the L2 alloc lands at LAUNCH scope, outside any per-core + // forall created by Phase A. + // Two-pack flows run Phase C AFTER Phase B (L2 pack) so the L2 alloc + // takes the packed shape matching the L1 pack's expected operand layout. + bool singlePackLevel = clL1PackSizes.empty(); + auto runPhaseC = [&]() -> LogicalResult { + if (!clBufferizeOutputL2) + return success(); + return runBufferizeOutputL2Impl( + f, clBufferizeOutputL2MemorySpace, clFuseOutputTruncfFirst, + clTileL3ToL2Copies, clKL2Tile, kCopyALoop, kCopyBLoop, rewriter); + }; + + if (singlePackLevel) + if (failed(runPhaseC())) + return fail(); + + // ---------- Phase A: launch tile (skip if empty) ---------- + if (!clLaunchTile.empty()) { + if (failed(runTileLaunchTileImpl(f, clLaunchTile, kLaunchTileForall, + rewriter))) + return fail(); + } + + // ---------- Phase B: L2 pack (skip if empty) ---------- + // The L2 pack bufferizes its output to L1 only in single-pack-level flows + // (l1-pack-sizes empty) AND when bufferize-last-pack-output is true. + // Two-pack-level flows defer L1 output bufferization to Phase D (L1 pack). + if (!clL2PackSizes.empty()) { + bool bufferizeL2OutputToL1 = singlePackLevel && clBufferizeLastPackOutput; + if (failed(runPackAndTransposeImpl( + f, clL2PackSizes, clL2LhsOuterPerm, clL2LhsInnerPerm, + clL2RhsOuterPerm, clL2RhsInnerPerm, clL2AccOuterPerm, + clL2AccInnerPerm, kPackedMatmul, + /*doBufferizeL1Output=*/bufferizeL2OutputToL1, + /*memSpace=*/clL1OutputMemorySpace, rewriter))) + return fail(); + if (!canonicalizeCse()) + return fail(); + } + + if (!singlePackLevel) + if (failed(runPhaseC())) + return fail(); + + // ---------- Phase D: L1 pack (skip if empty) ---------- + // The L1 pack is the LAST pack in two-pack flows, so its output is + // bufferized to L1 when bufferize-last-pack-output is true. + if (!clL1PackSizes.empty()) { + if (failed(runPackAndTransposeImpl( + f, clL1PackSizes, clL1LhsOuterPerm, clL1LhsInnerPerm, + clL1RhsOuterPerm, clL1RhsInnerPerm, clL1AccOuterPerm, + clL1AccInnerPerm, kPackedMatmul, + /*doBufferizeL1Output=*/clBufferizeLastPackOutput, + /*memSpace=*/clL1OutputMemorySpace, rewriter))) + return fail(); + } + + // ---------- Phase E: outer K-tile + fuse packs (skip if 0) ---------- + if (clOuterKTileFactor > 0) { + if (failed(runTileKAndFusePacksImpl( + f, clOuterKTileFactor, clOuterKIterIndex, kPackedMatmul, + kKReductionLoop, kLhsPackInK, kRhsPackInK, kLhsL2PackInK, + kRhsL2PackInK, rewriter))) + return fail(); + // Phase F: bufferize L2 inputs (always paired with two-pack outer-K-tile + // since the L2 packs were chain-fused). Skip if no L1 pack was done + // (single-pack-level flow doesn't have L2 packs to bufferize here). + if (!clL1PackSizes.empty()) { + if (failed(runBufferizeL1InputsImpl(f, /*memSpace=*/1, + /*memcpyOp=*/"linalg-copy", + kLhsL2PackInK, kRhsL2PackInK, + rewriter))) + return fail(); + } else if (clCoreTile.empty()) { + // Phase F': single-pack flow with NO tile-cores (e.g. a launch-tile- + // only flow). The L1 packs from Phase E are tagged lhs_pack_in_k / + // rhs_pack_in_k and need bufferization to L1 here, since Phase J + // (which uses fused_*_l1_pack markers) won't fire. + if (failed(runBufferizeL1InputsImpl(f, /*memSpace=*/2, + /*memcpyOp=*/"materialize", + kLhsPackInK, kRhsPackInK, + rewriter))) + return fail(); + } + if (!canonicalizeCse()) + return fail(); + } + + // ---------- Phase H: tile cores (skip if empty) ---------- + if (!clCoreTile.empty()) { + if (failed(runTileCoresImpl(f, clCoreTile, kPackedMatmul, kLhsPackInK, + kRhsPackInK, kComputeForall, kMatmulCompute, + kFusedLhsL1Pack, kFusedRhsL1Pack, rewriter))) + return fail(); + if (!canonicalizeCse()) + return fail(); + } + + // ---------- Phase I: inner K-tile (skip if 0) ---------- + if (clInnerKTileFactor > 0) { + if (failed(runTileKAndFusePacksImpl( + f, clInnerKTileFactor, clInnerKIterIndex, kPackedMatmul, + kKReductionLoopInner, kFusedLhsL1Pack, kFusedRhsL1Pack, + kLhsL2PackInK, kRhsL2PackInK, rewriter))) + return fail(); + } + + // ---------- Phase J: bufferize L1 inputs (skip if no tile-cores) + // ---------- + if (!clCoreTile.empty()) { + if (failed(runBufferizeL1InputsImpl(f, /*memSpace=*/2, + /*memcpyOp=*/"materialize", + kFusedLhsL1Pack, kFusedRhsL1Pack, + rewriter))) + return fail(); + if (!canonicalizeCse()) + return fail(); + } + + // ---------- Phase K: prologue/epilogue (skip if both tiles empty) + // ---------- + if (!clPrologueTile.empty() || !clEpilogueTile.empty()) { + if (failed(runPrologueEpilogueImpl(f, clPrologueTile, clEpilogueTile, + clFillIterPerm, kInitFill, + kPrologueForall, kEpilogueForall, + clHoistStaticAllocFirst, rewriter))) + return fail(); + if (!canonicalizeCse()) + return fail(); + } + + // ---------- Phase L: one-shot bufferize (gated; default true) ---------- + if (clOneShotBufferize) { + if (!runModuleScoped(module, [](OpPassManager &pm) { + bufferization::OneShotBufferizePassOptions opts; + opts.bufferizeFunctionBoundaries = true; + opts.functionBoundaryTypeConversion = + bufferization::LayoutMapOption::IdentityLayoutMap; + opts.unknownTypeConversion = + bufferization::LayoutMapOption::IdentityLayoutMap; + pm.addPass(bufferization::createOneShotBufferizePass(opts)); + })) + return fail(); + // canonicalize, cse, canonicalize (mirrors the legacy pipeline). + if (!runFuncScoped(f, [](OpPassManager &pm) { + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + })) + return fail(); + } + + // ---------- Phase M: tile for vectorize (skip if empty) ---------- + if (!clMatmulVecTile.empty()) { + if (failed(runTileForVectorizeImpl( + f, clMatmulVecTile, clMatmulUnrollVecTile, clMatmulUnrollFactor, + clFillVecTile, clPostBufferizeCleanupFirst, rewriter))) + return fail(); + } + + // ---------- Phase N: vec prep composite (always runs; no-op on + // pre-vectorize IR as the steps walk for ops that + // don't exist yet) ---------- + if (failed(runCodegenVecPrepImpl( + f, clVecPrepCast1TargetElementType, clVecPrepCast1InputIndices, + clVecPrepCast1OutputIndices, clVecPrepCast2TargetElementType, + clVecPrepCast2InputIndices, clVecPrepCast2OutputIndices, + clVecPrepHoistCastPairs, clVecPrepHoistCastPairsMaxIterations, + rewriter))) + return fail(); + + return success(); + } +}; + +} // namespace + +std::unique_ptr createAIRMatmulCodegenPass() { + return std::make_unique(); +} + +std::unique_ptr +createAIRMatmulCodegenPass(const AIRMatmulCodegenOptions &opts) { + return std::make_unique(opts); +} + +} // namespace air +} // namespace xilinx diff --git a/mlir/lib/Transform/AIRMatmulCodegenHelpers.cpp b/mlir/lib/Transform/AIRMatmulCodegenHelpers.cpp new file mode 100644 index 000000000..c9f69125a --- /dev/null +++ b/mlir/lib/Transform/AIRMatmulCodegenHelpers.cpp @@ -0,0 +1,939 @@ +//===- AIRMatmulCodegenHelpers.cpp ------------------------------*- C++ -*-===// +// +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT +// +//===----------------------------------------------------------------------===// + +#include "air/Transform/AIRMatmulCodegenHelpers.h" +#include "air/Util/Dependency.h" +#include "air/Util/Util.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "llvm/ADT/SmallVector.h" + +using namespace mlir; + +namespace xilinx { +namespace air { + +//===----------------------------------------------------------------------===// +// Pure predicates / utilities. Only those needed by helpers landed so far +// are defined; others arrive as their consuming runFoo functions land. +//===----------------------------------------------------------------------===// + +static bool areEquivalentIndices(Value idx1, Value idx2) { + if (idx1 == idx2) + return true; + Operation *def1 = idx1.getDefiningOp(); + Operation *def2 = idx2.getDefiningOp(); + if (!def1 || !def2) + return false; + // affine.apply with the same map AND same operands is value-equivalent. + // air::isEquivalentTo's lite check (constants only) misses this case. + if (auto a1 = dyn_cast(def1)) { + if (auto a2 = dyn_cast(def2)) { + if (a1.getAffineMap() != a2.getAffineMap()) + return false; + if (a1.getMapOperands().size() != a2.getMapOperands().size()) + return false; + for (auto [op1, op2] : + llvm::zip(a1.getMapOperands(), a2.getMapOperands())) + if (op1 != op2) + return false; + return true; + } + } + return xilinx::air::isEquivalentTo(def1, def2); +} + +static bool dependsOnLoopIV(Value val, Value loopIV) { + if (val == loopIV) + return true; + SmallVector deps; + std::vector opHist; + xilinx::air::traceDependentInductionVar({val}, deps, opHist); + return llvm::is_contained(deps, loopIV); +} + +bool hasWritesBetweenReads(vector::TransferReadOp firstRead, + vector::TransferReadOp secondRead) { + Value sourceMemref = firstRead.getBase(); + + Block *block = firstRead->getBlock(); + if (block != secondRead->getBlock()) + return true; // Conservative: different blocks, assume writes. + + auto firstIt = firstRead->getIterator(); + auto secondIt = secondRead->getIterator(); + for (auto it = ++firstIt; it != secondIt; ++it) { + Operation *op = &(*it); + + auto memInterface = dyn_cast_if_present(op); + if (!memInterface) { + // Conservative: if effects can't be queried and op may recurse into + // nested regions with writes, assume a write. + if (!op->hasTrait()) + continue; + return true; + } + + SmallVector effects; + memInterface.getEffects(effects); + for (auto &effect : effects) { + if (!isa(effect.getEffect())) + continue; + Value effectValue = effect.getValue(); + if (!effectValue) + return true; + if (effectValue == sourceMemref) + return true; + if (auto subview = effectValue.getDefiningOp()) + if (subview.getSource() == sourceMemref) + return true; + } + } + return false; +} + +//===----------------------------------------------------------------------===// +// runFoldUnitExtentDimsOnFunc +//===----------------------------------------------------------------------===// + +LogicalResult runFoldUnitExtentDimsOnFunc(func::FuncOp funcOp) { + MLIRContext *ctx = funcOp.getContext(); + + RewritePatternSet foldPatterns(ctx); + linalg::ControlDropUnitDims options; + // LLVM 23's collapseValue rejects memrefs with non-identity layouts (strided + // memrefs from subview ops). Override collapseFn to use rank-reducing + // memref.subview for strided memrefs, allowing the fold to handle linalg ops + // with subview outputs inside air.herd regions. + options.collapseFn = + [](RewriterBase &rewriter, Location loc, Value operand, + ArrayRef targetShape, + ArrayRef reassociation, + const linalg::ControlDropUnitDims &control) -> FailureOr { + if (auto memrefType = dyn_cast(operand.getType())) { + if (!memrefType.getLayout().isIdentity()) { + return memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand, + targetShape); + } + MemRefLayoutAttrInterface layout; + auto targetType = + MemRefType::get(targetShape, memrefType.getElementType(), layout, + memrefType.getMemorySpace()); + return memref::CollapseShapeOp::create(rewriter, loc, targetType, operand, + reassociation) + .getResult(); + } + return failure(); + }; + linalg::populateFoldUnitExtentDimsPatterns(foldPatterns, options); + (void)applyPatternsGreedily(funcOp, std::move(foldPatterns)); + return success(); +} + +//===----------------------------------------------------------------------===// +// runEliminateRedundantVectorTransfers +//===----------------------------------------------------------------------===// + +int runEliminateRedundantVectorTransfers(Operation *target, + RewriterBase &rewriter) { + SmallVector transferReads; + target->walk( + [&](vector::TransferReadOp readOp) { transferReads.push_back(readOp); }); + + llvm::SmallDenseSet eliminated; + int eliminatedCount = 0; + for (size_t i = 0; i < transferReads.size(); ++i) { + if (eliminated.contains(transferReads[i])) + continue; + for (size_t j = i + 1; j < transferReads.size(); ++j) { + if (eliminated.contains(transferReads[j])) + continue; + vector::TransferReadOp firstRead = transferReads[i]; + vector::TransferReadOp secondRead = transferReads[j]; + // Value-aware equivalence (matches the transform-op path in + // AIRLinalgCodegen.cpp::areIdenticalReads). OperationEquivalence is + // strict on operand SSA equality, which misses two reads whose indices + // are computed by distinct-but-identical affine.apply ops or two + // iter_args with the same initial value. + if (firstRead.getBase() != secondRead.getBase()) + continue; + if (firstRead.getIndices().size() != secondRead.getIndices().size()) + continue; + bool indicesMatch = true; + for (auto [idx1, idx2] : + llvm::zip(firstRead.getIndices(), secondRead.getIndices())) { + if (!areEquivalentIndices(idx1, idx2)) { + indicesMatch = false; + break; + } + } + if (!indicesMatch) + continue; + if (firstRead.getVector().getType() != secondRead.getVector().getType()) + continue; + if (hasWritesBetweenReads(firstRead, secondRead)) + continue; + rewriter.replaceAllUsesWith(secondRead.getResult(), + firstRead.getResult()); + rewriter.eraseOp(secondRead); + eliminated.insert(secondRead); + ++eliminatedCount; + } + } + return eliminatedCount; +} + +//===----------------------------------------------------------------------===// +// runFlattenForIterArgs +//===----------------------------------------------------------------------===// + +FailureOr runFlattenForIterArgs(scf::ForOp forOp, + RewriterBase &rewriter) { + Location loc = forOp.getLoc(); + + // Collect vector-typed iter_args. + SmallVector vectorIterArgIndices; + SmallVector originalVectorTypes; + SmallVector flattenedVectorTypes; + for (auto [idx, iterArg] : llvm::enumerate(forOp.getInitArgs())) { + if (auto vecType = dyn_cast_if_present(iterArg.getType())) { + vectorIterArgIndices.push_back(idx); + originalVectorTypes.push_back(vecType); + int64_t numElements = vecType.getNumElements(); + flattenedVectorTypes.push_back( + VectorType::get({numElements}, vecType.getElementType())); + } + } + + if (vectorIterArgIndices.empty()) + return forOp; + + // Step 1: insert shape_cast before the loop to flatten init values. + rewriter.setInsertionPoint(forOp); + SmallVector newInitArgs(forOp.getInitArgs().begin(), + forOp.getInitArgs().end()); + for (auto [idx, vecIdx] : llvm::enumerate(vectorIterArgIndices)) { + Value initArg = forOp.getInitArgs()[vecIdx]; + auto shapeCast = vector::ShapeCastOp::create( + rewriter, loc, flattenedVectorTypes[idx], initArg); + newInitArgs[vecIdx] = shapeCast.getResult(); + } + + // Step 2: build new result types. + SmallVector newResultTypes; + for (auto [idx, resultType] : llvm::enumerate(forOp.getResultTypes())) { + auto it = llvm::find(vectorIterArgIndices, idx); + if (it != vectorIterArgIndices.end()) { + size_t vecIdx = std::distance(vectorIterArgIndices.begin(), it); + newResultTypes.push_back(flattenedVectorTypes[vecIdx]); + } else { + newResultTypes.push_back(resultType); + } + } + + // Step 3: create new scf.for with flattened iter_args. + auto newForOp = + scf::ForOp::create(rewriter, loc, forOp.getLowerBound(), + forOp.getUpperBound(), forOp.getStep(), newInitArgs); + + // Step 4: clone the body, inserting shape_cast back to original shape for + // vector iter_args inside the loop. + Block *oldBody = forOp.getBody(); + Block *newBody = newForOp.getBody(); + rewriter.setInsertionPointToStart(newBody); + IRMapping mapping; + mapping.map(oldBody->getArgument(0), newBody->getArgument(0)); + for (auto [idx, vecIdx] : llvm::enumerate(vectorIterArgIndices)) { + BlockArgument newArg = newBody->getArgument(vecIdx + 1); + auto shapeCast = vector::ShapeCastOp::create( + rewriter, loc, originalVectorTypes[idx], newArg); + mapping.map(oldBody->getArgument(vecIdx + 1), shapeCast.getResult()); + } + for (auto [idx, arg] : + llvm::enumerate(oldBody->getArguments().drop_front(1))) { + if (llvm::find(vectorIterArgIndices, idx) == vectorIterArgIndices.end()) + mapping.map(arg, newBody->getArgument(idx + 1)); + } + for (Operation &op : oldBody->without_terminator()) + rewriter.clone(op, mapping); + + // Step 5: rebuild yield, flattening vector values. + auto oldYield = cast(oldBody->getTerminator()); + SmallVector newYieldOperands; + for (auto [idx, yieldValue] : llvm::enumerate(oldYield.getOperands())) { + auto it = llvm::find(vectorIterArgIndices, idx); + if (it != vectorIterArgIndices.end()) { + size_t vecIdx = std::distance(vectorIterArgIndices.begin(), it); + Value mappedValue = mapping.lookup(yieldValue); + auto shapeCast = vector::ShapeCastOp::create( + rewriter, loc, flattenedVectorTypes[vecIdx], mappedValue); + newYieldOperands.push_back(shapeCast.getResult()); + } else { + newYieldOperands.push_back(mapping.lookup(yieldValue)); + } + } + scf::YieldOp::create(rewriter, loc, newYieldOperands); + + // Step 6: insert shape_cast back after the loop and replace uses. + rewriter.setInsertionPointAfter(newForOp); + SmallVector finalResults; + for (auto [idx, result] : llvm::enumerate(newForOp.getResults())) { + auto it = llvm::find(vectorIterArgIndices, idx); + if (it != vectorIterArgIndices.end()) { + size_t vecIdx = std::distance(vectorIterArgIndices.begin(), it); + auto shapeCast = vector::ShapeCastOp::create( + rewriter, loc, originalVectorTypes[vecIdx], result); + finalResults.push_back(shapeCast.getResult()); + } else { + finalResults.push_back(result); + } + } + rewriter.replaceOp(forOp, finalResults); + return newForOp; +} + +//===----------------------------------------------------------------------===// +// runHoistLoopInvariantTransfers +//===----------------------------------------------------------------------===// + +static Value cloneOpAndOperands(Operation *op, Value loopIV, scf::ForOp loopOp, + RewriterBase &rewriter, IRMapping &mapping) { + if (!op->getResults().empty()) + if (mapping.contains(op->getResult(0))) + return mapping.lookup(op->getResult(0)); + + // Producer slice filter: only clone ops that live inside the loop, are not + // already mapped, and don't transitively depend on the IV. The top-level + // loop below pre-walks `op`'s operands; this filter is what prunes the + // backward slice that air::cloneOpAndOperands then computes per-operand. + auto canClone = [loopIV, loopOp, &mapping](Operation *o) { + if (!loopOp->isAncestor(o)) + return false; + if (o->getResults().empty()) + return false; + if (mapping.contains(o->getResult(0))) + return false; + return !dependsOnLoopIV(o->getResult(0), loopIV); + }; + + for (Value operand : op->getOperands()) { + if (operand == loopIV) + continue; + if (mapping.contains(operand)) + continue; + if (isa(operand)) + continue; // Outer-loop block args still in scope. + Operation *defOp = operand.getDefiningOp(); + if (!defOp || !loopOp->isAncestor(defOp)) + continue; // Defined outside the loop, already in scope. + if (dependsOnLoopIV(operand, loopIV)) + continue; + Operation *clonedDef = + xilinx::air::cloneOpAndOperands(rewriter, mapping, defOp, canClone); + if (!clonedDef->getResults().empty()) + mapping.map(operand, clonedDef->getResult(0)); + } + + Operation *cloned = rewriter.clone(*op, mapping); + if (cloned->getResults().empty()) + return nullptr; + return cloned->getResult(0); +} + +namespace { + +/// Hoist a single transfer_read/transfer_write pair out of `loopOp`. The +/// read is cloned before the loop, the write is cloned after the loop, and +/// the accumulator value flows through a new iter_arg. +FailureOr hoistTransferPairFromLoop(vector::TransferReadOp readOp, + vector::TransferWriteOp writeOp, + scf::ForOp loopOp, + RewriterBase &rewriter) { + Value loopIV = loopOp.getInductionVar(); + + rewriter.setInsertionPoint(loopOp); + IRMapping readMapping; + Value clonedReadResult = + cloneOpAndOperands(readOp, loopIV, loopOp, rewriter, readMapping); + + Value writeVector = writeOp.getVector(); + auto yieldValuesFn = + [&](OpBuilder &b, Location loc, + ArrayRef newBbArgs) -> SmallVector { + BlockArgument readIterArg = newBbArgs.back(); + rewriter.replaceAllUsesWith(readOp.getResult(), readIterArg); + return {writeVector}; + }; + + FailureOr newLoopResult = + cast(loopOp.getOperation()) + .replaceWithAdditionalYields(rewriter, ValueRange{clonedReadResult}, + /*replaceInitOperandUsesInLoop=*/true, + yieldValuesFn); + if (failed(newLoopResult)) + return failure(); + + auto newLoop = cast(newLoopResult->getOperation()); + rewriter.eraseOp(readOp); + + Value valueToWrite = newLoop.getResults().back(); + IRMapping writeMapping; + writeMapping.map(writeVector, valueToWrite); + rewriter.setInsertionPointAfter(newLoop); + + for (Value index : writeOp.getIndices()) { + Operation *defOp = index.getDefiningOp(); + if (!defOp || dependsOnLoopIV(index, loopIV)) + continue; + if (!newLoop->isProperAncestor(defOp)) + continue; + if (!writeMapping.contains(index)) { + Value clonedIndex = + cloneOpAndOperands(defOp, loopIV, newLoop, rewriter, writeMapping); + if (clonedIndex) + writeMapping.map(index, clonedIndex); + } + } + + rewriter.clone(*writeOp.getOperation(), writeMapping); + rewriter.eraseOp(writeOp); + return newLoop; +} + +} // namespace + +FailureOr runHoistLoopInvariantTransfers(Operation *scopeOp, + scf::ForOp loopOp, + RewriterBase &rewriter) { + if (!scopeOp->isProperAncestor(loopOp)) + return loopOp->emitError("loop must be inside the scope operation"); + + scf::ForOp currentLoop = loopOp; + while (true) { + Value loopIV = currentLoop.getInductionVar(); + vector::TransferWriteOp foundWrite = nullptr; + vector::TransferReadOp foundRead = nullptr; + + currentLoop->walk([&](vector::TransferWriteOp writeOp) { + if (foundWrite) + return; + if (writeOp->getParentOfType() != currentLoop) + return; + for (Value index : writeOp.getIndices()) + if (dependsOnLoopIV(index, loopIV)) + return; + + currentLoop->walk([&](vector::TransferReadOp readOp) { + if (foundRead) + return; + if (readOp->getParentOfType() != currentLoop) + return; + if (readOp.getBase() != writeOp.getBase()) + return; + for (Value index : readOp.getIndices()) + if (dependsOnLoopIV(index, loopIV)) + return; + if (readOp.getIndices().size() != writeOp.getIndices().size()) + return; + for (auto [ri, wi] : + llvm::zip(readOp.getIndices(), writeOp.getIndices())) + if (!areEquivalentIndices(ri, wi)) + return; + foundRead = readOp; + }); + if (foundRead) + foundWrite = writeOp; + }); + + if (!foundWrite || !foundRead) + break; + + FailureOr newLoop = + hoistTransferPairFromLoop(foundRead, foundWrite, currentLoop, rewriter); + if (failed(newLoop)) + return currentLoop->emitError("failed to hoist transfer pair"); + currentLoop = *newLoop; + } + + return currentLoop; +} + +//===----------------------------------------------------------------------===// +// runHoistVectorTransferPointers +//===----------------------------------------------------------------------===// + +LogicalResult runHoistVectorTransferPointers(scf::ForOp forOp, + RewriterBase &rewriter) { + Value loopIV = forOp.getInductionVar(); + Location loc = forOp.getLoc(); + OpBuilder::InsertionGuard guard(rewriter); + + struct TransferOpInfo { + Operation *op; + Value base; + MemRefType memrefType; + VectorType vectorType; + SmallVector indices; + int64_t constantStride; + bool hasIVDependentIndices; + }; + + SmallVector transferOps; + for (Operation &op : forOp.getBody()->without_terminator()) { + auto transferOp = dyn_cast_if_present(&op); + if (!transferOp) + continue; + Value base = transferOp.getBase(); + auto memrefType = dyn_cast_if_present(base.getType()); + if (!memrefType) + continue; + VectorType vectorType; + if (auto readOp = dyn_cast_if_present(&op)) { + vectorType = readOp.getVectorType(); + } else if (auto writeOp = + dyn_cast_if_present(&op)) { + vectorType = writeOp.getVectorType(); + } else { + continue; + } + SmallVector indices(transferOp.getIndices().begin(), + transferOp.getIndices().end()); + bool hasIVDependentIndices = false; + int64_t constantStride = 0; + for (size_t dimIdx = 0; dimIdx < indices.size(); ++dimIdx) { + Value idx = indices[dimIdx]; + if (dependsOnLoopIV(idx, loopIV)) { + hasIVDependentIndices = true; + int64_t dimStride = 1; + for (size_t j = dimIdx + 1; + j < static_cast(memrefType.getRank()); ++j) + dimStride *= memrefType.getShape()[j]; + // Assumes IV coefficient is 1 (index = IV or IV+const). This is the + // total stride increment per loop iteration. + constantStride += dimStride; + } + } + transferOps.push_back({&op, base, memrefType, vectorType, indices, + constantStride, hasIVDependentIndices}); + } + + // Prepare iter_args (one base pointer per IV-dependent transfer). + SmallVector newInitArgs; + SmallVector flatMemrefs; + for (const auto &info : transferOps) { + if (!info.hasIVDependentIndices) + continue; + rewriter.setInsertionPoint(forOp); + Value flatMemref = info.base; + if (info.memrefType.getRank() > 1) { + int64_t totalSize = 1; + for (int64_t dim : info.memrefType.getShape()) { + if (dim == ShapedType::kDynamic) + return forOp->emitError("dynamic memref shapes not supported"); + totalSize *= dim; + } + MemRefType flatMemrefType = + MemRefType::get({totalSize}, info.memrefType.getElementType(), + AffineMap(), info.memrefType.getMemorySpace()); + SmallVector reassociation; + ReassociationIndices allDims; + for (size_t i = 0; i < static_cast(info.memrefType.getRank()); + ++i) + allDims.push_back(i); + reassociation.push_back(allDims); + flatMemref = memref::CollapseShapeOp::create( + rewriter, loc, flatMemrefType, info.base, reassociation); + } + flatMemrefs.push_back(flatMemref); + + int64_t rank = info.memrefType.getRank(); + AffineExpr linearExpr = rewriter.getAffineConstantExpr(0); + int64_t stride = 1; + for (int64_t i = rank - 1; i >= 0; --i) { + linearExpr = linearExpr + rewriter.getAffineDimExpr(i) * stride; + if (i > 0) + stride *= info.memrefType.getShape()[i]; + } + auto linearMap = AffineMap::get(rank, 0, linearExpr); + + SmallVector baseIndices; + IRMapping indexMapping; + for (Value idx : info.indices) { + if (!dependsOnLoopIV(idx, loopIV)) { + if (auto defOp = idx.getDefiningOp()) { + Value clonedIdx = + cloneOpAndOperands(defOp, loopIV, forOp, rewriter, indexMapping); + baseIndices.push_back(clonedIdx ? clonedIdx : idx); + } else { + baseIndices.push_back(idx); + } + } else { + baseIndices.push_back(arith::ConstantIndexOp::create(rewriter, loc, 0)); + } + } + Value basePointer = + affine::AffineApplyOp::create(rewriter, loc, linearMap, baseIndices); + newInitArgs.push_back(basePointer); + } + + // No IV-dependent transfers: rewrite each transfer to a 1D form using a + // freshly-computed pointer per use, no iter_arg needed. + if (newInitArgs.empty()) { + for (const auto &info : transferOps) { + rewriter.setInsertionPoint(info.op); + int64_t numElements = info.vectorType.getNumElements(); + VectorType flatVectorType = + VectorType::get({numElements}, info.vectorType.getElementType()); + + rewriter.setInsertionPoint(forOp); + Value flatMemref = info.base; + if (info.memrefType.getRank() > 1) { + int64_t totalSize = 1; + for (int64_t dim : info.memrefType.getShape()) + totalSize *= dim; + MemRefType flatMemrefType = + MemRefType::get({totalSize}, info.memrefType.getElementType(), + AffineMap(), info.memrefType.getMemorySpace()); + SmallVector reassociation; + ReassociationIndices allDims; + for (size_t i = 0; i < static_cast(info.memrefType.getRank()); + ++i) + allDims.push_back(i); + reassociation.push_back(allDims); + flatMemref = memref::CollapseShapeOp::create( + rewriter, loc, flatMemrefType, info.base, reassociation); + } + + int64_t rank = info.memrefType.getRank(); + AffineExpr linearExpr = rewriter.getAffineConstantExpr(0); + int64_t stride = 1; + for (int64_t i = rank - 1; i >= 0; --i) { + linearExpr = linearExpr + rewriter.getAffineDimExpr(i) * stride; + if (i > 0) + stride *= info.memrefType.getShape()[i]; + } + auto linearMap = AffineMap::get(rank, 0, linearExpr); + + rewriter.setInsertionPoint(info.op); + Value currentPointer = + affine::AffineApplyOp::create(rewriter, loc, linearMap, info.indices); + + AffineMap identityMap1D = AffineMap::get( + 1, 0, rewriter.getAffineDimExpr(0), rewriter.getContext()); + auto inBoundsAttr = rewriter.getBoolArrayAttr({true}); + + if (auto readOp = dyn_cast_if_present(info.op)) { + Value flatRead = vector::TransferReadOp::create( + rewriter, loc, flatVectorType, flatMemref, + ValueRange{currentPointer}, AffineMapAttr::get(identityMap1D), + readOp.getPadding(), /*mask=*/Value(), inBoundsAttr); + Value shapedRead = vector::ShapeCastOp::create( + rewriter, loc, info.vectorType, flatRead); + rewriter.replaceOp(readOp, shapedRead); + } else if (auto writeOp = + dyn_cast_if_present(info.op)) { + Value flatValue = vector::ShapeCastOp::create( + rewriter, loc, flatVectorType, writeOp.getVector()); + rewriter.replaceOpWithNewOp( + writeOp, flatValue, flatMemref, ValueRange{currentPointer}, + AffineMapAttr::get(identityMap1D), /*mask=*/Value(), inBoundsAttr); + } + } + return success(); + } + + // IV-dependent transfers: thread base pointers as iter_args, advance by + // constant stride per iteration. + auto yieldValuesFn = + [&](OpBuilder &b, Location yieldLoc, + ArrayRef newBbArgs) -> SmallVector { + SmallVector yieldValues; + size_t iterArgIdx = 0; + for (size_t i = 0; i < transferOps.size(); ++i) { + const auto &info = transferOps[i]; + if (!info.hasIVDependentIndices) + continue; + BlockArgument ptrIterArg = + newBbArgs[newBbArgs.size() - newInitArgs.size() + iterArgIdx]; + Value flatMemref = flatMemrefs[iterArgIdx]; + + int64_t numElements = info.vectorType.getNumElements(); + VectorType flatVectorType = + VectorType::get({numElements}, info.vectorType.getElementType()); + b.setInsertionPoint(info.op); + AffineMap identityMap1D = + AffineMap::get(1, 0, b.getAffineDimExpr(0), b.getContext()); + auto inBoundsAttr = b.getBoolArrayAttr({true}); + + if (auto readOp = dyn_cast_if_present(info.op)) { + Value flatRead = vector::TransferReadOp::create( + b, loc, flatVectorType, flatMemref, ValueRange{ptrIterArg}, + AffineMapAttr::get(identityMap1D), readOp.getPadding(), + /*mask=*/Value(), inBoundsAttr); + Value shapedRead = + vector::ShapeCastOp::create(b, loc, info.vectorType, flatRead); + rewriter.replaceOp(readOp, shapedRead); + } else if (auto writeOp = + dyn_cast_if_present(info.op)) { + Value flatValue = vector::ShapeCastOp::create(b, loc, flatVectorType, + writeOp.getVector()); + rewriter.replaceOpWithNewOp( + writeOp, flatValue, flatMemref, ValueRange{ptrIterArg}, + AffineMapAttr::get(identityMap1D), /*mask=*/Value(), inBoundsAttr); + } + + Value strideConst = + arith::ConstantIndexOp::create(b, yieldLoc, info.constantStride); + Value nextPtr = + arith::AddIOp::create(b, yieldLoc, ptrIterArg, strideConst); + yieldValues.push_back(nextPtr); + ++iterArgIdx; + } + return yieldValues; + }; + + FailureOr newLoopResult = + cast(forOp.getOperation()) + .replaceWithAdditionalYields(rewriter, newInitArgs, + /*replaceInitOperandUsesInLoop=*/true, + yieldValuesFn); + if (failed(newLoopResult)) + return forOp->emitError("failed to add pointer iter_args to loop"); + return success(); +} + +//===----------------------------------------------------------------------===// +// runHoistCastPair +//===----------------------------------------------------------------------===// + +FailureOr runHoistCastPair(Operation *extensionOp, + Operation *truncationOp, + scf::ForOp loopOp, + RewriterBase &rewriter) { + Value extensionInput, extensionOutput; + Value truncationInput, truncationOutput; + bool isFloatingPoint = false; + + if (auto extsiOp = dyn_cast_if_present(extensionOp)) { + extensionInput = extsiOp.getIn(); + extensionOutput = extsiOp.getOut(); + auto trunciOp = dyn_cast_if_present(truncationOp); + if (!trunciOp) + return extensionOp->emitError( + "arith.extsi must be paired with arith.trunci"); + truncationInput = trunciOp.getIn(); + truncationOutput = trunciOp.getOut(); + } else if (auto extuiOp = dyn_cast_if_present(extensionOp)) { + extensionInput = extuiOp.getIn(); + extensionOutput = extuiOp.getOut(); + auto trunciOp = dyn_cast_if_present(truncationOp); + if (!trunciOp) + return extensionOp->emitError( + "arith.extui must be paired with arith.trunci"); + truncationInput = trunciOp.getIn(); + truncationOutput = trunciOp.getOut(); + } else if (auto extfOp = dyn_cast_if_present(extensionOp)) { + extensionInput = extfOp.getIn(); + extensionOutput = extfOp.getOut(); + auto truncfOp = dyn_cast_if_present(truncationOp); + if (!truncfOp) + return extensionOp->emitError( + "arith.extf must be paired with arith.truncf"); + truncationInput = truncfOp.getIn(); + truncationOutput = truncfOp.getOut(); + isFloatingPoint = true; + } else { + return extensionOp->emitError( + "extension operation must be arith.extsi, arith.extui, or arith.extf"); + } + + if (!loopOp->isProperAncestor(extensionOp) || + !loopOp->isProperAncestor(truncationOp)) + return loopOp->emitError( + "extension and truncation operations must be inside the loop"); + + // Find which iter_arg the extension operates on (directly or via shape_cast). + BlockArgument iterArg = nullptr; + int64_t iterArgIndex = -1; + vector::ShapeCastOp shapeCastBeforeExtension = nullptr; + if (auto blockArg = dyn_cast_if_present(extensionInput)) { + if (blockArg.getOwner() == loopOp.getBody() && + blockArg.getArgNumber() > 0) { + iterArg = blockArg; + iterArgIndex = blockArg.getArgNumber() - 1; + } + } else if (auto shapeCastOp = + extensionInput.getDefiningOp()) { + Value src = shapeCastOp.getSource(); + if (auto blockArg = dyn_cast_if_present(src)) { + if (blockArg.getOwner() == loopOp.getBody() && + blockArg.getArgNumber() > 0) { + iterArg = blockArg; + iterArgIndex = blockArg.getArgNumber() - 1; + shapeCastBeforeExtension = shapeCastOp; + } + } + } + if (!iterArg) + return extensionOp->emitError("extension must operate on a loop iter_arg " + "(directly or via shape_cast)"); + + // The yielded value must come from the truncation (possibly via shape_cast) + // and feed the same iter_arg position. + vector::ShapeCastOp shapeCastAfterTruncation = nullptr; + auto yieldOp = cast(loopOp.getBody()->getTerminator()); + bool truncationIsYielded = false; + int64_t yieldIndex = -1; + for (auto [idx, yieldValue] : llvm::enumerate(yieldOp.getOperands())) { + if (yieldValue == truncationOutput) { + truncationIsYielded = true; + yieldIndex = idx; + break; + } else if (auto shapeCast = + yieldValue.getDefiningOp()) { + if (shapeCast.getSource() == truncationOutput) { + truncationIsYielded = true; + yieldIndex = idx; + shapeCastAfterTruncation = shapeCast; + break; + } + } + } + if (!truncationIsYielded || yieldIndex != iterArgIndex) + return loopOp->emitError("truncation result must be yielded at the same " + "position as the extension iter_arg"); + + Location loc = loopOp.getLoc(); + + // Step 1: extend the init value before the loop. + rewriter.setInsertionPoint(loopOp); + Value initValue = loopOp.getInitArgs()[iterArgIndex]; + Type wideElemType = + cast(extensionOutput.getType()).getElementType(); + Type wideInitType = VectorType::get( + cast(initValue.getType()).getShape(), wideElemType); + Value extendedInit; + if (isFloatingPoint) + extendedInit = + arith::ExtFOp::create(rewriter, loc, wideInitType, initValue); + else if (isa(extensionOp)) + extendedInit = + arith::ExtSIOp::create(rewriter, loc, wideInitType, initValue); + else + extendedInit = + arith::ExtUIOp::create(rewriter, loc, wideInitType, initValue); + + // Step 2: build new loop with the wide iter_arg. + SmallVector newInitArgs(loopOp.getInitArgs().begin(), + loopOp.getInitArgs().end()); + newInitArgs[iterArgIndex] = extendedInit; + auto newLoopOp = + scf::ForOp::create(rewriter, loc, loopOp.getLowerBound(), + loopOp.getUpperBound(), loopOp.getStep(), newInitArgs); + + // Step 3: clone the loop body, adjusting types as needed. + Block *oldBody = loopOp.getBody(); + Block *newBody = newLoopOp.getBody(); + rewriter.setInsertionPointToStart(newBody); + IRMapping mapping; + mapping.map(oldBody->getArgument(0), newBody->getArgument(0)); + for (auto [idx, oldArg] : + llvm::enumerate(oldBody->getArguments().drop_front(1))) + mapping.map(oldArg, newBody->getArgument(idx + 1)); + + for (Operation &op : oldBody->without_terminator()) { + if (&op == extensionOp) { + if (!shapeCastBeforeExtension) { + // No shape_cast: extension result becomes the wide iter_arg directly. + mapping.map(extensionOutput, newBody->getArgument(iterArgIndex + 1)); + } + continue; + } + if (&op == truncationOp) + continue; // Yield handled below. + if (shapeCastBeforeExtension && + &op == shapeCastBeforeExtension.getOperation()) { + auto narrowVecType = + cast(shapeCastBeforeExtension.getResult().getType()); + auto wideVecType = + VectorType::get(narrowVecType.getShape(), wideElemType); + Value mappedSource = mapping.lookup(shapeCastBeforeExtension.getSource()); + auto newShapeCast = + vector::ShapeCastOp::create(rewriter, loc, wideVecType, mappedSource); + mapping.map(shapeCastBeforeExtension.getResult(), + newShapeCast.getResult()); + mapping.map(extensionOutput, newShapeCast.getResult()); + continue; + } + if (shapeCastAfterTruncation && + &op == shapeCastAfterTruncation.getOperation()) + continue; // Handled in yield processing. + rewriter.clone(op, mapping); + } + + // Step 4: build new yield with the wide value. + auto oldYield = cast(oldBody->getTerminator()); + SmallVector newYieldOperands; + for (auto [idx, yieldValue] : llvm::enumerate(oldYield.getOperands())) { + if ((int64_t)idx == iterArgIndex) { + Value wideValue = mapping.lookup(truncationInput); + if (shapeCastAfterTruncation) { + auto narrowVecType = + cast(shapeCastAfterTruncation.getResult().getType()); + auto wideVecType = + VectorType::get(narrowVecType.getShape(), wideElemType); + auto newShapeCast = + vector::ShapeCastOp::create(rewriter, loc, wideVecType, wideValue); + newYieldOperands.push_back(newShapeCast.getResult()); + } else { + newYieldOperands.push_back(wideValue); + } + } else { + newYieldOperands.push_back(mapping.lookup(yieldValue)); + } + } + scf::YieldOp::create(rewriter, loc, newYieldOperands); + + // Step 5: truncate the wide loop result back to narrow type. + rewriter.setInsertionPointAfter(newLoopOp); + Value wideResult = newLoopOp.getResults()[iterArgIndex]; + auto narrowElemType = + cast(loopOp.getInitArgs()[iterArgIndex].getType()) + .getElementType(); + auto narrowResultType = VectorType::get( + cast(wideResult.getType()).getShape(), narrowElemType); + Value narrowResult; + if (isFloatingPoint) + narrowResult = + arith::TruncFOp::create(rewriter, loc, narrowResultType, wideResult); + else + narrowResult = + arith::TruncIOp::create(rewriter, loc, narrowResultType, wideResult); + + // Step 6: replace uses of the old loop. + SmallVector finalResults; + for (auto [idx, result] : llvm::enumerate(newLoopOp.getResults())) { + if ((int64_t)idx == iterArgIndex) + finalResults.push_back(narrowResult); + else + finalResults.push_back(result); + } + rewriter.replaceOp(loopOp, finalResults); + return newLoopOp; +} + +} // namespace air +} // namespace xilinx diff --git a/mlir/lib/Transform/AIRMatmulPackAndTranspose.cpp b/mlir/lib/Transform/AIRMatmulPackAndTranspose.cpp new file mode 100644 index 000000000..4464e9cc8 --- /dev/null +++ b/mlir/lib/Transform/AIRMatmulPackAndTranspose.cpp @@ -0,0 +1,179 @@ +//===- AIRMatmulPackAndTranspose.cpp ---------------------------*- C++ -*-===// +// +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT +// +//===----------------------------------------------------------------------===// + +#include "air/Transform/AIRMatmulPackAndTranspose.h" +#include "air/Transform/AIRMatmulBufferizationPasses.h" + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/PatternMatch.h" + +#include "llvm/ADT/SmallVector.h" + +#include + +#define DEBUG_TYPE "air-matmul-pack-and-transpose" + +using namespace mlir; +using namespace xilinx::air; + +namespace xilinx { +namespace air { + +namespace { + +// Apply pack_transpose to the producer of `linalgOp` operand `operandIdx`. +// Updates `linalgOp` in-place and returns the new linalg op on success. +static FailureOr +applyOperandTranspose(RewriterBase &rewriter, linalg::LinalgOp linalgOp, + int64_t operandIdx, ArrayRef outerPerm, + ArrayRef innerPerm) { + if (outerPerm.empty() && innerPerm.empty()) + return linalgOp; + Value operand = linalgOp->getOperand(operandIdx); + auto packOp = operand.getDefiningOp(); + if (!packOp) + return linalgOp->emitError() << "operand " << operandIdx + << " is not produced by a linalg.pack op"; + // For an output operand, packTranspose also walks to the consumer unpack. + linalg::UnPackOp maybeUnPack; + if (operandIdx == (int64_t)linalgOp.getNumDpsInputs()) { + for (auto user : linalgOp->getUsers()) { + if (auto u = dyn_cast(user)) { + maybeUnPack = u; + break; + } + } + if (!maybeUnPack) + return linalgOp->emitError() + << "output operand has no unpack consumer; cannot transpose"; + } + auto res = linalg::packTranspose(rewriter, packOp, linalgOp, maybeUnPack, + outerPerm, innerPerm); + if (failed(res)) + return linalgOp->emitError() + << "packTranspose failed for operand " << operandIdx; + return cast(res->transposedLinalgOp.getOperation()); +} + +// Apply linalg::pack + per-operand pack_transpose to a single matmul. +static LogicalResult +runOnMatmul(linalg::LinalgOp matmulOp, ArrayRef packSizes, + ArrayRef lhsOuter, ArrayRef lhsInner, + ArrayRef rhsOuter, ArrayRef rhsInner, + ArrayRef accOuter, ArrayRef accInner, + StringRef marker, RewriterBase &rewriter) { + rewriter.setInsertionPoint(matmulOp); + + // Snapshot discardable attrs (e.g. air.matmul_codegen_config) before pack + // rewrites the op into a new linalg.generic that doesn't inherit them. + SmallVector savedAttrs( + matmulOp->getDiscardableAttrs().begin(), + matmulOp->getDiscardableAttrs().end()); + + // Build OpFoldResult sizes for linalg::pack. + SmallVector packed; + packed.reserve(packSizes.size()); + for (int64_t s : packSizes) + packed.push_back(rewriter.getIndexAttr(s)); + + auto packResult = linalg::pack(rewriter, matmulOp, packed); + if (failed(packResult)) + return matmulOp->emitError() << "linalg::pack failed"; + linalg::LinalgOp current = packResult->packedLinalgOp; + + // Per-operand transposes. Operand order on the packed op: 0=LHS, 1=RHS, + // 2=accumulator (the only DPS init for matmul). + auto step = [&](int64_t idx, ArrayRef outer, + ArrayRef inner) -> LogicalResult { + auto res = applyOperandTranspose(rewriter, current, idx, outer, inner); + if (failed(res)) + return failure(); + current = *res; + return success(); + }; + if (failed(step(0, lhsOuter, lhsInner))) + return failure(); + if (failed(step(1, rhsOuter, rhsInner))) + return failure(); + if (failed(step(2, accOuter, accInner))) + return failure(); + + // Re-attach discardable attrs (the codegen config, etc.) to the final + // packed/transposed op so downstream consumer passes can read them. + for (NamedAttribute a : savedAttrs) + if (!current->hasAttr(a.getName())) + current->setDiscardableAttr(a.getName(), a.getValue()); + + if (!marker.empty()) + current->setAttr(marker, rewriter.getUnitAttr()); + return success(); +} + +} // namespace + +LogicalResult +runPackAndTransposeImpl(func::FuncOp f, ArrayRef packSizes, + ArrayRef lhsOuter, ArrayRef lhsInner, + ArrayRef rhsOuter, ArrayRef rhsInner, + ArrayRef accOuter, ArrayRef accInner, + StringRef packedMatmulMarker, bool doBufferizeL1Output, + int64_t bufferizeL1OutputMemorySpace, + RewriterBase &rewriter) { + // Find the first linalg.matmul; if none, fall back to the first + // linalg.generic carrying the `packed_matmul` marker (= already-packed + // matmul, eligible for a second pack level in two-pack flows). + linalg::LinalgOp target; + f.walk([&](linalg::MatmulOp op) { + target = cast(op.getOperation()); + return WalkResult::interrupt(); + }); + if (!target) { + f.walk([&](linalg::GenericOp op) { + if (op->hasAttr(packedMatmulMarker)) { + target = cast(op.getOperation()); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + } + if (!target) { + // No matmul to pack; treat as a no-op (other passes may have already + // packed it into a generic without the marker). + return success(); + } + + // Validate pack-sizes vs op iterator count. First-pack expects 3 + // (matmul m,n,k); second-pack on an already-packed op expects 6 + // (m,n,k outer + m,n,k inner) and may include zeros to leave outer + // dims unpacked. + int64_t numIters = target.getNumLoops(); + if ((int64_t)packSizes.size() != numIters) { + target->emitError() << "pack-sizes has " << packSizes.size() + << " entries; op has " << numIters << " iterators"; + return failure(); + } + + if (failed(runOnMatmul(target, packSizes, lhsOuter, lhsInner, rhsOuter, + rhsInner, accOuter, accInner, packedMatmulMarker, + rewriter))) + return failure(); + + // Optional tail step: bufferize the output linalg.pack into an L1 (or + // configurable memory-space) allocation. Replaces the former standalone + // `air-matmul-bufferize-l1-output` pass. + if (doBufferizeL1Output) { + if (failed(runBufferizeL1OutputImpl(f, bufferizeL1OutputMemorySpace, + packedMatmulMarker, rewriter))) + return failure(); + } + return success(); +} + +} // namespace air +} // namespace xilinx diff --git a/mlir/lib/Transform/AIRMatmulTileL3ToL2Copies.cpp b/mlir/lib/Transform/AIRMatmulTileL3ToL2Copies.cpp new file mode 100644 index 000000000..3b117219e --- /dev/null +++ b/mlir/lib/Transform/AIRMatmulTileL3ToL2Copies.cpp @@ -0,0 +1,111 @@ +//===- AIRMatmulTileL3ToL2Copies.cpp ---------------------------*- C++ -*-===// +// +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT +// +//===----------------------------------------------------------------------===// +// +// Free-function body for the former `air-matmul-tile-l3-to-l2-copies` pass. +// Now invoked from `air-matmul-bufferize-output-l2` when its +// `do-tile-l3-to-l2-copies` option is set. +// +//===----------------------------------------------------------------------===// + +#include "air/Transform/AIRMatmulTileL3ToL2Copies.h" +#include "air/Transform/AIRMatmulCodegenHelpers.h" + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" +#include "mlir/IR/Builders.h" +#include "mlir/Interfaces/TilingInterface.h" + +#define DEBUG_TYPE "air-matmul-tile-l3-to-l2-copies" + +using namespace mlir; +using namespace xilinx::air; + +namespace xilinx { +namespace air { + +namespace { + +// Walk back from a matmul tensor operand to the linalg.copy that fills the +// memref later read by `bufferization.to_tensor`. Returns nullptr if the +// chain doesn't match the expected shape (pre-bufferization Triton-XDNA-style +// IR). +static linalg::CopyOp findCopyForOperand(Value matmulOperand) { + auto toTensor = matmulOperand.getDefiningOp(); + if (!toTensor) + return nullptr; + Value memref = toTensor.getBuffer(); + for (Operation *user : memref.getUsers()) { + auto copyOp = dyn_cast(user); + if (!copyOp) + continue; + if (copyOp.getDpsInits().size() != 1) + continue; + if (copyOp.getDpsInits()[0] == memref) + return copyOp; + } + return nullptr; +} + +// Tile a 2D linalg.copy by `tileSizes` (one OpFoldResult per dim; zero means +// not tiled). Annotates the produced scf.for with `marker` (unit attr). +static LogicalResult tileCopyAndAnnotate(linalg::CopyOp copyOp, + ArrayRef tileSizes, + StringRef marker) { + IRRewriter rewriter(copyOp.getContext()); + rewriter.setInsertionPoint(copyOp); + auto tilingIface = cast(copyOp.getOperation()); + scf::SCFTilingOptions tilingOpts; + tilingOpts.setTileSizes(tileSizes); + auto result = scf::tileUsingSCF(rewriter, tilingIface, tilingOpts); + if (failed(result)) + return copyOp->emitError() << "scf::tileUsingSCF failed"; + rewriter.replaceOp(copyOp, result->replacements); + + if (marker.empty() || result->loops.empty()) + return success(); + Operation *outerLoop = result->loops.front().getOperation(); + outerLoop->setAttr(marker, rewriter.getUnitAttr()); + return success(); +} + +} // namespace + +LogicalResult runTileL3ToL2CopiesImpl(func::FuncOp func, int64_t kL2Tile, + StringRef copyAMarker, + StringRef copyBMarker) { + if (failed(runConvertMemrefCopyToLinalgCopy(func))) + return failure(); + + linalg::MatmulOp matmul; + func.walk([&](linalg::MatmulOp op) { + matmul = op; + return WalkResult::interrupt(); + }); + if (!matmul) + return success(); // no matmul; nothing more to do. + + linalg::CopyOp copyA = findCopyForOperand(matmul->getOperand(0)); + linalg::CopyOp copyB = findCopyForOperand(matmul->getOperand(1)); + + OpBuilder b(func.getContext()); + OpFoldResult zero = b.getIndexAttr(0); + OpFoldResult kTile = b.getIndexAttr(kL2Tile); + + // LHS layout is (M, K): tile dim 1 (= K). RHS layout is (K, N): tile dim + // 0 (= K). If a copy isn't found, skip silently — re-running is a no-op. + if (copyA && failed(tileCopyAndAnnotate(copyA, {zero, kTile}, copyAMarker))) + return failure(); + if (copyB && failed(tileCopyAndAnnotate(copyB, {kTile, zero}, copyBMarker))) + return failure(); + return success(); +} + +} // namespace air +} // namespace xilinx diff --git a/mlir/lib/Transform/AIRMatmulTilePasses.cpp b/mlir/lib/Transform/AIRMatmulTilePasses.cpp new file mode 100644 index 000000000..5c2d5fb90 --- /dev/null +++ b/mlir/lib/Transform/AIRMatmulTilePasses.cpp @@ -0,0 +1,477 @@ +//===- AIRMatmulTilePasses.cpp ----------------------------------*- C++ -*-===// +// +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT +// +//===----------------------------------------------------------------------===// +// +// Tiling phases of the air-matmul-codegen orchestrator: launch-tile, +// tile-k-and-fuse-packs, tile-cores, prologue-epilogue. Each tiles the +// (packed) matmul on a different axis and fuses its operand-producing +// pack ops into the new loop. Markers wired so downstream phases +// (bufferize-l1-inputs, fuse-pingpong-loops) can find their targets. +// +//===----------------------------------------------------------------------===// + +#include "air/Transform/AIRMatmulTilePasses.h" +#include "air/Transform/AIRMatmulBufferizationPasses.h" +#include "air/Util/Util.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Interfaces/TilingInterface.h" + +#include "llvm/ADT/StringRef.h" + +#define DEBUG_TYPE "air-matmul-tile-passes" + +using namespace mlir; + +namespace xilinx { +namespace air { + +namespace { + +// `findMarkedOp` lives in air/Util/Util.h as `xilinx::air::findOpWithAttr`. + +/// Build OpFoldResult-typed tile sizes (one per iterator dim) from int64s. +/// Pads with 0 if shorter than `numIters`; truncates if longer. +static SmallVector +buildTileSizes(ArrayRef sizes, int64_t numIters, MLIRContext *ctx) { + SmallVector out; + out.reserve(numIters); + OpBuilder b(ctx); + for (int64_t i = 0; i < numIters; ++i) { + int64_t v = (i < (int64_t)sizes.size()) ? sizes[i] : 0; + out.push_back(b.getIndexAttr(v)); + } + return out; +} + +/// Fuse a linalg.fill that lives just outside `forall` into the forall body +/// when its result feeds a `shared_outs` operand. After fusion the shared_outs +/// operand becomes the original fill destination (e.g. tensor.empty) and a +/// per-iter linalg.fill is cloned inside the body, before the consuming +/// linalg op, filling the corresponding extract_slice. Returns success when +/// the pattern matched and was fused. +static LogicalResult fuseFillIntoForallSharedOuts(linalg::FillOp fillOp, + scf::ForallOp forall, + RewriterBase &rewriter) { + Value fillResult = fillOp.getResult(0); + int64_t sharedOutIdx = -1; + for (auto [idx, val] : llvm::enumerate(forall.getOutputs())) { + if (val == fillResult) { + sharedOutIdx = idx; + break; + } + } + if (sharedOutIdx < 0) + return failure(); + + BlockArgument blockArg = forall.getRegionIterArgs()[sharedOutIdx]; + Value fillDest = fillOp.getOutputs()[0]; // typically tensor.empty + Value fillValue = fillOp.getInputs()[0]; + + // Find consumer of the block arg (or extract_slice on it) inside the body + // that should be re-initialized per-iter. Match a linalg op whose init + // operand is an extract_slice on blockArg. + linalg::LinalgOp consumer; + tensor::ExtractSliceOp consumerSlice; + forall.getBody()->walk([&](linalg::LinalgOp op) { + if (op.getNumDpsInits() != 1) + return WalkResult::advance(); + auto es = op.getDpsInits()[0].getDefiningOp(); + if (!es || es.getSource() != blockArg) + return WalkResult::advance(); + consumer = op; + consumerSlice = es; + return WalkResult::interrupt(); + }); + if (!consumer) + return failure(); + + // Re-source the shared_outs from the original empty (the fill destination). + forall.getOutputsMutable()[sharedOutIdx].set(fillDest); + + // Clone a per-iter fill into the body, filling the extract_slice. + rewriter.setInsertionPoint(consumer); + auto newFill = + linalg::FillOp::create(rewriter, fillOp.getLoc(), ValueRange{fillValue}, + ValueRange{consumerSlice.getResult()}); + rewriter.modifyOpInPlace(consumer, [&]() { + consumer.getDpsInitsMutable()[0].set(newFill.getResult(0)); + }); + + // Erase the outside fill (its only use is the shared_outs slot we just + // re-sourced, plus any tensor.empty chain — leave the empty for DCE). + if (fillOp.getResult(0).use_empty()) + rewriter.eraseOp(fillOp); + return success(); +} + +/// Fuse a producer LinalgOp's first tensor.extract_slice user inside `loop` +/// into the loop, returning the fused (tiled) op. This mirrors what +/// `transform.structured.fuse_into_containing_op` does for tensor producers. +static Operation *fuseProducerIntoLoop(Operation *producerOp, + LoopLikeOpInterface loop, + RewriterBase &rewriter) { + if (!producerOp || !loop) + return nullptr; + ResultRange producerResults = producerOp->getResults(); + tensor::ExtractSliceOp slice; + loop->walk([&](tensor::ExtractSliceOp s) { + if (llvm::is_contained(producerResults, s.getSource())) { + slice = s; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (!slice) + return nullptr; + SmallVector loops{loop}; + auto res = scf::tileAndFuseProducerOfSlice(rewriter, slice, loops); + if (!res || res->tiledOps.empty()) + return nullptr; + return res->tiledOps.front(); +} + +/// Tile `target` with `LoopType::ForallOp` and pre-built `tileSizes`. Returns +/// the full `SCFTilingResult` on success; the original op is `replaceOp`d. +static FailureOr +tileAsForallResult(Operation *target, ArrayRef tileSizes, + RewriterBase &rewriter) { + auto tileable = dyn_cast_if_present(target); + if (!tileable) + return failure(); + rewriter.setInsertionPoint(target); + scf::SCFTilingOptions opts; + opts.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); + opts.setTileSizes(tileSizes); + auto res = scf::tileUsingSCF(rewriter, tileable, opts); + if (failed(res)) + return failure(); + rewriter.replaceOp(target, res->replacements); + return res; +} + +/// Convenience wrapper around `tileAsForallResult` for callers that only need +/// the new forall loop and accept padded raw int64_t tile sizes. +static LoopLikeOpInterface tileAsForall(Operation *target, + ArrayRef tileSizes, + RewriterBase &rewriter) { + if (!target) + return {}; + auto tileable = dyn_cast(target); + if (!tileable) + return {}; + auto folded = buildTileSizes( + tileSizes, tileable.getLoopIteratorTypes().size(), target->getContext()); + auto res = tileAsForallResult(target, folded, rewriter); + if (failed(res)) + return {}; + return res->loops.empty() ? LoopLikeOpInterface() : res->loops.front(); +} + +} // namespace + +//===----------------------------------------------------------------------===// +// runTileKAndFusePacksImpl (Phase 4) +//===----------------------------------------------------------------------===// + +LogicalResult runTileKAndFusePacksImpl( + func::FuncOp f, int64_t kTileFactor, int64_t kIterIndex, + StringRef packedMatmulMarker, StringRef kReductionLoopMarker, + StringRef lhsPackMarker, StringRef rhsPackMarker, StringRef lhsL2PackMarker, + StringRef rhsL2PackMarker, RewriterBase &rewriter) { + Operation *packedMatmulOp = + xilinx::air::findOpWithAttr(f, packedMatmulMarker); + if (!packedMatmulOp) + return success(); + auto matmul = dyn_cast(packedMatmulOp); + if (!matmul) { + packedMatmulOp->emitError("packed_matmul op must be a LinalgOp"); + return failure(); + } + + // Identify pack producers of operand 0 (LHS) and operand 1 (RHS) BEFORE + // tiling — tiling rewrites the operands and would invalidate these. + Operation *packA = matmul.getDpsInputs()[0].getDefiningOp(); + Operation *packB = matmul.getDpsInputs()[1].getDefiningOp(); + + // Tile on the K iterator. Matmul iterators after pack: m0,n0,k0,m1,n1,k1 + // (3 outer + 3 inner) for standard pack [m,n,k]. K iterator index = 2. + int64_t numIters = matmul.getNumLoops(); + SmallVector raw(numIters, 0); + if (numIters < 3) { + packedMatmulOp->emitError( + "packed_matmul has fewer than 3 iterators; expected M, N, K"); + return failure(); + } + if (kIterIndex < 0 || kIterIndex >= numIters) { + packedMatmulOp->emitError("k-iter-index ") + << kIterIndex << " out of range [0, " << numIters << ")"; + return failure(); + } + raw[kIterIndex] = kTileFactor; + auto tileSizes = buildTileSizes(raw, numIters, f.getContext()); + + auto tileable = cast(packedMatmulOp); + rewriter.setInsertionPoint(packedMatmulOp); + scf::SCFTilingOptions opts; + opts.setTileSizes(tileSizes); + auto tilingResult = scf::tileUsingSCF(rewriter, tileable, opts); + if (failed(tilingResult)) { + packedMatmulOp->emitError("scf::tileUsingSCF on K failed"); + return failure(); + } + rewriter.replaceOp(packedMatmulOp, tilingResult->replacements); + + if (tilingResult->loops.empty()) + return success(); // K tile of 0; nothing more to do. + LoopLikeOpInterface kLoop = tilingResult->loops.front(); + kLoop->setAttr(kReductionLoopMarker, rewriter.getUnitAttr()); + + // Fuse pack_a and pack_b into the K loop. Annotate. For two-pack-level + // flows where the matmul's immediate operand pack (L1) has a grandparent + // pack (L2) feeding it, recursively fuse the producer chain so the L2 + // pack ends up at K-loop scope too. + auto fuseChain = [&](Operation *pack, StringRef l1Marker, + StringRef l2Marker) { + bool producerHadL1Marker = pack && pack->hasAttr(l1Marker); + Operation *fused = fuseProducerIntoLoop(pack, kLoop, rewriter); + if (!fused) + return; + if (producerHadL1Marker && pack->getBlock()) + pack->removeAttr(l1Marker); + fused->setAttr(l1Marker, rewriter.getUnitAttr()); + if (auto innerPack = dyn_cast(fused)) { + Value src = innerPack.getSource(); + while (auto es = src.getDefiningOp()) + src = es.getSource(); + if (auto gp = src.getDefiningOp()) { + if (!kLoop->isProperAncestor(gp)) { + if (Operation *l2Fused = fuseProducerIntoLoop(gp, kLoop, rewriter)) + l2Fused->setAttr(l2Marker, rewriter.getUnitAttr()); + } + } + } + }; + fuseChain(packA, lhsPackMarker, lhsL2PackMarker); + fuseChain(packB, rhsPackMarker, rhsL2PackMarker); + return success(); +} + +//===----------------------------------------------------------------------===// +// runTileCoresImpl (Phase 5) +//===----------------------------------------------------------------------===// + +LogicalResult +runTileCoresImpl(func::FuncOp f, ArrayRef tileSizes, + StringRef packedMatmulMarker, StringRef lhsPackInKMarker, + StringRef rhsPackInKMarker, StringRef computeForallMarker, + StringRef matmulComputeMarker, StringRef lhsL1PackMarker, + StringRef rhsL1PackMarker, RewriterBase &rewriter) { + Operation *packedMatmulOp = + xilinx::air::findOpWithAttr(f, packedMatmulMarker); + if (!packedMatmulOp) + return success(); + auto matmul = dyn_cast(packedMatmulOp); + if (!matmul) { + packedMatmulOp->emitError("packed_matmul op must be a LinalgOp"); + return failure(); + } + + auto folded = buildTileSizes(tileSizes, matmul.getNumLoops(), f.getContext()); + + auto tilingResult = tileAsForallResult(packedMatmulOp, folded, rewriter); + if (failed(tilingResult)) { + packedMatmulOp->emitError("scf::tileUsingSCF (forall) failed"); + return failure(); + } + + if (tilingResult->loops.empty()) + return success(); + LoopLikeOpInterface forall = tilingResult->loops.front(); + forall->setAttr(computeForallMarker, rewriter.getUnitAttr()); + + // Per-core matmul body: only one tiledOp expected. + if (!tilingResult->tiledOps.empty()) + tilingResult->tiledOps.front()->setAttr(matmulComputeMarker, + rewriter.getUnitAttr()); + + // Fuse the K-loop-fused packs into the forall. + Operation *lhsPack = xilinx::air::findOpWithAttr(f, lhsPackInKMarker); + Operation *rhsPack = xilinx::air::findOpWithAttr(f, rhsPackInKMarker); + if (Operation *fusedA = fuseProducerIntoLoop(lhsPack, forall, rewriter)) + fusedA->setAttr(lhsL1PackMarker, rewriter.getUnitAttr()); + if (Operation *fusedB = fuseProducerIntoLoop(rhsPack, forall, rewriter)) + fusedB->setAttr(rhsL1PackMarker, rewriter.getUnitAttr()); + return success(); +} + +//===----------------------------------------------------------------------===// +// runPrologueEpilogueImpl (Phase 6 prologue/epilogue) +//===----------------------------------------------------------------------===// + +LogicalResult runPrologueEpilogueImpl( + func::FuncOp f, ArrayRef prologueTileSizes, + ArrayRef epilogueTileSizes, + ArrayRef fillIteratorInterchange, StringRef initFillMarker, + StringRef prologueForallMarker, StringRef epilogueForallMarker, + bool hoistStaticAllocFirst, RewriterBase &rewriter) { + // Optional pre-step: hoist statically-bound memref.alloc ops out of + // nested loops to the function entry block. Used by two-pack-level flows + // so the L1 acc alloc lives outside the K-reduction loop (K-peel flow). + if (hoistStaticAllocFirst) + runHoistStaticAllocImpl(f, rewriter); + + // ---- Prologue: generalize+interchange+tile the linalg.fill ---- + // The prologue must execute BEFORE the compute work. Find the compute + // forall (or its ancestor scf.for) and move the fill in front of it + // before generalizing/tiling so the resulting prologue forall lands at + // the correct position. + linalg::FillOp fill; + f.walk([&](linalg::FillOp op) { + fill = op; + return WalkResult::interrupt(); + }); + if (fill) { + Operation *anchor = nullptr; + f.walk([&](scf::ForOp forOp) { + if (forOp->hasAttr("k_reduction_loop")) { + anchor = forOp.getOperation(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (!anchor) { + f.walk([&](scf::ForallOp forallOp) { + if (forallOp->hasAttr("compute_forall")) { + anchor = forallOp.getOperation(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + } + if (anchor) { + Block *fillBlock = fill->getBlock(); + while (anchor && anchor->getBlock() != fillBlock) + anchor = anchor->getParentOp(); + if (anchor && !fill->isBeforeInBlock(anchor)) + fill->moveBefore(anchor); + } + rewriter.setInsertionPoint(fill); + FailureOr generic = + linalg::generalizeNamedOp(rewriter, fill); + if (failed(generic)) { + fill->emitError("generalizeNamedOp failed"); + return failure(); + } + generic->getOperation()->setAttr(initFillMarker, rewriter.getUnitAttr()); + + Operation *fillTileTarget = generic->getOperation(); + // Interchange iterators if a non-empty perm was provided. + if (!fillIteratorInterchange.empty()) { + SmallVector permUnsigned(fillIteratorInterchange.begin(), + fillIteratorInterchange.end()); + FailureOr interchanged = + linalg::interchangeGenericOp(rewriter, *generic, permUnsigned); + if (failed(interchanged)) { + generic->getOperation()->emitError("interchangeGenericOp failed"); + return failure(); + } + // Re-stamp the marker on the new op. + interchanged->getOperation()->setAttr(initFillMarker, + rewriter.getUnitAttr()); + fillTileTarget = interchanged->getOperation(); + } + + LoopLikeOpInterface prologueForall = + tileAsForall(fillTileTarget, prologueTileSizes, rewriter); + if (prologueForall) + prologueForall->setAttr(prologueForallMarker, rewriter.getUnitAttr()); + } + + // ---- Epilogue: tile the linalg.unpack ---- + linalg::UnPackOp unpack; + f.walk([&](linalg::UnPackOp op) { + unpack = op; + return WalkResult::interrupt(); + }); + if (unpack) { + LoopLikeOpInterface epilogueForall = + tileAsForall(unpack, epilogueTileSizes, rewriter); + if (epilogueForall) + epilogueForall->setAttr(epilogueForallMarker, rewriter.getUnitAttr()); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// runTileLaunchTileImpl +//===----------------------------------------------------------------------===// + +LogicalResult runTileLaunchTileImpl(func::FuncOp f, ArrayRef tileSizes, + StringRef launchTileForallMarker, + RewriterBase &rewriter) { + linalg::MatmulOp matmul; + f.walk([&](linalg::MatmulOp op) { + matmul = op; + return WalkResult::interrupt(); + }); + if (!matmul) + return success(); + + auto folded = buildTileSizes(tileSizes, + cast(matmul.getOperation()) + .getLoopIteratorTypes() + .size(), + f.getContext()); + + // Capture the linalg.fill producer of the matmul's accumulator BEFORE + // tiling (after which the matmul is rewritten and producer linkage may + // shift through extract_slice). + Operation *fillProducer = + matmul.getOutputs()[0].getDefiningOp(); + + auto tilingResult = + tileAsForallResult(matmul.getOperation(), folded, rewriter); + if (failed(tilingResult)) { + matmul->emitError("scf::tileUsingSCF (forall) on launch-tile failed"); + return failure(); + } + + if (tilingResult->loops.empty()) + return success(); + LoopLikeOpInterface forall = tilingResult->loops.front(); + forall->setAttr(launchTileForallMarker, rewriter.getUnitAttr()); + + // Tag the inner (per-launch-tile) matmul with `matmul_compute` so that + // downstream tile-for-vectorize (which only matches inHerd ops or + // `matmul_compute`-tagged ops) can find it in launch-tile-only flows + // where there is no separate tile-cores step. The marker is preserved + // by linalg::pack (which copies discardable attrs). + if (!tilingResult->tiledOps.empty()) + tilingResult->tiledOps.front()->setAttr("matmul_compute", + rewriter.getUnitAttr()); + + if (fillProducer) { + auto fillOp = dyn_cast(fillProducer); + auto forallOp = dyn_cast(forall.getOperation()); + if (fillOp && forallOp) + (void)fuseFillIntoForallSharedOuts(fillOp, forallOp, rewriter); + } + return success(); +} + +} // namespace air +} // namespace xilinx diff --git a/mlir/lib/Transform/AIRMatmulVectorizePasses.cpp b/mlir/lib/Transform/AIRMatmulVectorizePasses.cpp new file mode 100644 index 000000000..4daf1a1d7 --- /dev/null +++ b/mlir/lib/Transform/AIRMatmulVectorizePasses.cpp @@ -0,0 +1,400 @@ +//===- AIRMatmulVectorizePasses.cpp ----------------------------*- C++ -*-===// +// +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT +// +//===----------------------------------------------------------------------===// +// +// Vectorization-prep phases of the air-matmul-codegen orchestrator: +// tile-for-vectorize and the vec-prep composite. Each free function walks +// a func::FuncOp and dispatches to a runFoo helper in +// AIRMatmulCodegenHelpers; the helpers are shared with the corresponding +// transform.air.* op apply() in AIRLinalgCodegen.cpp. +// +//===----------------------------------------------------------------------===// + +#include "air/Transform/AIRMatmulVectorizePasses.h" + +#include "air/Dialect/AIR/AIRDialect.h" +#include "air/Transform/AIRMatmulBufferizationPasses.h" +#include "air/Transform/AIRMatmulCodegenHelpers.h" +#include "air/Transform/PassDetail.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/TilingInterface.h" +#include "mlir/Pass/Pass.h" + +#define DEBUG_TYPE "air-matmul-vectorize-passes" + +using namespace mlir; +using namespace xilinx::air; + +namespace xilinx { +namespace air { + +namespace { + +// True if the herd contains at least one vector.contract — i.e., it's a +// compute herd, not a fill/epilogue herd. Mirrors the script's targeting of +// `herd2_1` specifically (the compute herd). +static bool herdHasVectorContract(xilinx::air::HerdOp herd) { + bool found = false; + herd->walk([&](mlir::vector::ContractionOp) { + found = true; + return WalkResult::interrupt(); + }); + return found; +} + +// Collect every scf.for that lives inside an air.herd in `func` and has no +// further scf.for in its subtree. Optional `herdFilter` skips entire herds. +static SmallVector +findInnermostForsInHerds(func::FuncOp func, + function_ref herdFilter = nullptr) { + SmallVector innermost; + func.walk([&](HerdOp herd) { + if (herdFilter && !herdFilter(herd)) + return; + herd->walk([&](mlir::scf::ForOp forOp) { + bool hasInnerFor = false; + for (Operation &nested : forOp.getBody()->without_terminator()) { + if (isa(nested)) { + hasInnerFor = true; + break; + } + nested.walk([&](mlir::scf::ForOp) { hasInnerFor = true; }); + if (hasInnerFor) + break; + } + if (!hasInnerFor) + innermost.push_back(forOp); + }); + }); + return innermost; +} + +// Per-step bodies. Extracted from the previously-individual AIR passes; now +// invoked in fixed order from runCodegenVecPrepImpl below. + +static LogicalResult runFlattenForIterArgsStep(func::FuncOp func, + IRRewriter &rewriter) { + SmallVector targets; + func.walk([&](mlir::scf::ForOp forOp) { + for (Value v : forOp.getInitArgs()) + if (isa(v.getType())) { + targets.push_back(forOp); + break; + } + }); + for (mlir::scf::ForOp forOp : targets) { + auto res = runFlattenForIterArgs(forOp, rewriter); + if (failed(res)) + return forOp->emitError("flatten-for-iter-args failed"); + } + return success(); +} + +static LogicalResult runHoistLoopInvariantTransfersStep(func::FuncOp func, + IRRewriter &rewriter) { + // Innermost scf.for inside each herd; the helper requires vector.transfer + // pairs in the loop's immediate body. + for (mlir::scf::ForOp loopOp : findInnermostForsInHerds(func)) { + auto scopeOp = loopOp->getParentOfType(); + auto res = runHoistLoopInvariantTransfers(scopeOp, loopOp, rewriter); + if (failed(res)) + return loopOp->emitError("hoist-loop-invariant-transfers failed"); + } + return success(); +} + +static LogicalResult runHoistVectorTransferPointersStep(func::FuncOp func, + IRRewriter &rewriter) { + // Compute-herd-only filter: skip fill/epilogue herds so downstream + // air-shrink-memref-sizes-by-access can still split L1 buffers per-core. + for (mlir::scf::ForOp forOp : + findInnermostForsInHerds(func, herdHasVectorContract)) { + if (failed(runHoistVectorTransferPointers(forOp, rewriter))) + return forOp->emitError("hoist-vector-transfer-pointers failed"); + } + return success(); +} + +static LogicalResult runVectorCastForEmulationStep(func::FuncOp func, + StringRef targetElementType, + ArrayRef inIdx, + ArrayRef outIdx, + IRRewriter &rewriter) { + if (targetElementType.empty()) + return success(); // skip + MLIRContext *ctx = func.getContext(); + Type targetTy = llvm::StringSwitch(targetElementType) + .Case("f32", Float32Type::get(ctx)) + .Case("bf16", BFloat16Type::get(ctx)) + .Case("f16", Float16Type::get(ctx)) + .Case("i32", IntegerType::get(ctx, 32)) + .Case("i16", IntegerType::get(ctx, 16)) + .Case("i8", IntegerType::get(ctx, 8)) + .Default(Type()); + if (!targetTy) + return func->emitError("unknown target-element-type '") + << targetElementType << "'"; + SmallVector targets; + func.walk([&](mlir::vector::ContractionOp c) { targets.push_back(c); }); + for (mlir::vector::ContractionOp c : targets) { + if (failed(runVectorTypeCastOnTarget(c.getOperation(), targetTy, inIdx, + outIdx, rewriter))) + return c->emitError("vector_type_cast failed"); + } + return success(); +} + +// For each vector iter_arg of `forOp`, look for an extension that operates +// on it (directly or through a single shape_cast) and a truncation whose +// result is yielded back at the same iter_arg position. +static bool findNextPair(mlir::Operation *funcOp, mlir::Operation *&extOp, + mlir::Operation *&truncOp, mlir::scf::ForOp &loopOp) { + bool found = false; + funcOp->walk([&](xilinx::air::HerdOp herd) { + if (found) + return WalkResult::interrupt(); + herd->walk([&](mlir::scf::ForOp forOp) { + if (found) + return WalkResult::interrupt(); + auto yieldOp = + dyn_cast(forOp.getBody()->getTerminator()); + if (!yieldOp) + return WalkResult::advance(); + mlir::Block *body = forOp.getBody(); + for (auto [argIdx, blockArg] : + llvm::enumerate(body->getArguments().drop_front(1))) { + if (!isa(blockArg.getType())) + continue; + mlir::Operation *foundExt = nullptr; + for (mlir::Operation *user : blockArg.getUsers()) { + if (isa(user)) { + foundExt = user; + break; + } + if (auto sc = dyn_cast(user)) { + for (mlir::Operation *u2 : sc.getResult().getUsers()) { + if (isa(u2)) { + foundExt = u2; + break; + } + } + if (foundExt) + break; + } + } + if (!foundExt) + continue; + mlir::Value yieldedVal = yieldOp.getOperand((unsigned)argIdx); + mlir::Operation *foundTrunc = yieldedVal.getDefiningOp(); + if (auto sc = + dyn_cast_if_present(foundTrunc)) + foundTrunc = sc.getSource().getDefiningOp(); + if (!foundTrunc || + !isa(foundTrunc)) + continue; + extOp = foundExt; + truncOp = foundTrunc; + loopOp = forOp; + found = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return WalkResult::advance(); + }); + return found; +} + +static LogicalResult runHoistCastPairsStep(func::FuncOp func, + int64_t maxIterations, + IRRewriter &rewriter) { + int64_t budget = maxIterations; + while (budget-- > 0) { + mlir::Operation *extOp = nullptr; + mlir::Operation *truncOp = nullptr; + mlir::scf::ForOp loopOp; + if (!findNextPair(func.getOperation(), extOp, truncOp, loopOp)) + return success(); + auto res = runHoistCastPair(extOp, truncOp, loopOp, rewriter); + if (failed(res)) + return func->emitError("hoist-cast-pair failed"); + } + func->emitWarning( + "air-matmul-codegen-vec-prep hit hoist-cast-pairs-max-iterations cap; " + "remaining pairs not hoisted"); + return success(); +} + +// Tile a TilingInterface op by the given sizes, using scf.for. If `sizes` +// is shorter than the op's iteration domain rank, pads with zeros (matching +// `transform.structured.tile_using_for` semantics). Returns the produced +// loops on success. +static FailureOr> +tileWithScfFor(mlir::Operation *op, ArrayRef sizes, + IRRewriter &rewriter) { + auto iface = dyn_cast(op); + if (!iface) + return op->emitError("op does not implement TilingInterface"); + rewriter.setInsertionPoint(op); + mlir::scf::SCFTilingOptions opts; + unsigned numLoops = iface.getLoopIteratorTypes().size(); + if ((unsigned)sizes.size() > numLoops) + return op->emitError("tile sizes (") + << sizes.size() << ") exceed iteration domain rank (" << numLoops + << ")"; + SmallVector sizeFolds; + sizeFolds.reserve(numLoops); + for (int64_t s : sizes) + sizeFolds.push_back(rewriter.getIndexAttr(s)); + // Pad with zeros to match iteration domain rank. + while (sizeFolds.size() < numLoops) + sizeFolds.push_back(rewriter.getIndexAttr(0)); + opts.setTileSizes(sizeFolds); + auto res = mlir::scf::tileUsingSCF(rewriter, iface, opts); + if (failed(res)) + return op->emitError("tileUsingSCF failed"); + rewriter.replaceOp(op, res->replacements); + return res->loops; +} + +} // namespace + +LogicalResult runCodegenVecPrepImpl( + func::FuncOp func, StringRef cast1TargetElementType, + ArrayRef cast1InputIndices, ArrayRef cast1OutputIndices, + StringRef cast2TargetElementType, ArrayRef cast2InputIndices, + ArrayRef cast2OutputIndices, bool doHoistCastPairs, + int64_t hoistCastPairsMaxIterations, RewriterBase &rewriter) { + IRRewriter &irRewriter = static_cast(rewriter); + + if (failed(runFoldUnitExtentDimsOnFunc(func))) + return failure(); + (void)runEliminateRedundantVectorTransfers(func, irRewriter); + if (failed(runVectorCastForEmulationStep(func, cast1TargetElementType, + cast1InputIndices, + cast1OutputIndices, irRewriter))) + return failure(); + if (failed(runVectorCastForEmulationStep(func, cast2TargetElementType, + cast2InputIndices, + cast2OutputIndices, irRewriter))) + return failure(); + if (failed(runHoistLoopInvariantTransfersStep(func, irRewriter))) + return failure(); + if (failed(runFlattenForIterArgsStep(func, irRewriter))) + return failure(); + if (failed(runHoistVectorTransferPointersStep(func, irRewriter))) + return failure(); + if (doHoistCastPairs) + if (failed(runHoistCastPairsStep(func, hoistCastPairsMaxIterations, + irRewriter))) + return failure(); + return success(); +} + +LogicalResult runTileForVectorizeImpl(func::FuncOp func, + ArrayRef matmulTileSizes, + ArrayRef matmulUnrollTileSizes, + int64_t matmulUnrollFactor, + ArrayRef fillTileSizes, + bool doPostBufferizeCleanupFirst, + RewriterBase &rewriter) { + IRRewriter &irRewriter = static_cast(rewriter); + + // Optional pre-step: post-bufferize cleanup (remove uninitialized + // copies + eliminate cascade memcpys + sibling-fuse pingpong loops). + // Replaces the former standalone `air-matmul-post-bufferize-cleanup` + // pass. + if (doPostBufferizeCleanupFirst) + if (failed(runPostBufferizeCleanupImpl(func, rewriter))) + return failure(); + + // Phase 1: tile each linalg.generic packed-matmul body by matmulTileSizes. + // Accept ops that either (a) live inside an air.herd (iron-built flow) + // or (b) carry the `matmul_compute` marker (linalg-input flow runs this + // pass BEFORE the forall->herd materialization). + SmallVector matmulGenerics; + func.walk([&](mlir::linalg::GenericOp op) { + bool inHerd = op->getParentOfType() != nullptr; + bool isMatmulCompute = op->hasAttr("matmul_compute"); + if (!inHerd && !isMatmulCompute) + return; + if (op.getNumLoops() < (int64_t)matmulTileSizes.size()) + return; + matmulGenerics.push_back(op); + }); + for (mlir::linalg::GenericOp gen : matmulGenerics) { + auto loops1 = + tileWithScfFor(gen.getOperation(), matmulTileSizes, irRewriter); + if (failed(loops1)) + return failure(); + // After first tile, find the new inner linalg.generic (the only + // descendant of the produced loops). + mlir::linalg::GenericOp inner; + if (!loops1->empty()) { + loops1->back()->walk([&](mlir::linalg::GenericOp g) { + inner = g; + return WalkResult::interrupt(); + }); + } else { + inner = gen; // No tiling happened (zero sizes). Skip second tile. + } + if (!inner) + continue; + auto loops2 = + tileWithScfFor(inner.getOperation(), matmulUnrollTileSizes, irRewriter); + if (failed(loops2)) + return failure(); + // Unroll the two innermost produced loops. + // loops2->back() is the innermost; loops2 is in outer→inner order. + uint64_t factor = matmulUnrollFactor; + if (factor > 1) { + SmallVector toUnroll; + for (auto loop : *loops2) + if (auto sf = dyn_cast(loop.getOperation())) + toUnroll.push_back(sf); + // Unroll from innermost outward (last two). + for (auto it = toUnroll.rbegin(); + it != toUnroll.rend() && std::distance(toUnroll.rbegin(), it) < 2; + ++it) { + if (failed(mlir::loopUnrollByFactor(*it, factor))) { + it->emitError("loopUnrollByFactor failed"); + return failure(); + } + } + } + } + + // Phase 2: tile each linalg.fill (or linalg.generic carrying the + // `init_fill` marker, set by the prologue-epilogue phase after + // generalize+interchange) by fillTileSizes. + SmallVector fills; + func.walk([&](mlir::linalg::FillOp f) { + if (f->getParentOfType()) + fills.push_back(f.getOperation()); + }); + func.walk([&](mlir::linalg::GenericOp g) { + if (g->hasAttr("init_fill")) + fills.push_back(g.getOperation()); + }); + for (mlir::Operation *f : fills) { + auto loops = tileWithScfFor(f, fillTileSizes, irRewriter); + if (failed(loops)) + return failure(); + } + return success(); +} + +} // namespace air +} // namespace xilinx diff --git a/mlir/lib/Transform/CMakeLists.txt b/mlir/lib/Transform/CMakeLists.txt index 39d8c535f..8b4c411b7 100644 --- a/mlir/lib/Transform/CMakeLists.txt +++ b/mlir/lib/Transform/CMakeLists.txt @@ -23,6 +23,13 @@ list(APPEND TRANSFORM_SOURCES AIRLinalgCodegen.cpp AIRLinalgOpStats.cpp AIRLoopMergingPass.cpp + AIRMatmulBufferizationPasses.cpp + AIRMatmulCodegen.cpp + AIRMatmulCodegenHelpers.cpp + AIRMatmulPackAndTranspose.cpp + AIRMatmulTileL3ToL2Copies.cpp + AIRMatmulTilePasses.cpp + AIRMatmulVectorizePasses.cpp ) if(AIR_ENABLE_AIE) list(APPEND TRANSFORM_SOURCES diff --git a/mlir/lib/Transform/Passes.cpp b/mlir/lib/Transform/Passes.cpp index ed2f0601b..5cd5989bc 100644 --- a/mlir/lib/Transform/Passes.cpp +++ b/mlir/lib/Transform/Passes.cpp @@ -46,6 +46,7 @@ void xilinx::air::registerTransformPasses() { registerAIRLoopMergingPass(); registerAIRLoopPermutation(); registerAIRLowerHerdParallelPass(); + registerAIRMatmulCodegen(); registerAIROverrideMemRefMemorySpace(); registerAIRPipelineReducePass(); registerAIRRegularizeLoop(); diff --git a/mlir/lib/Util/CMakeLists.txt b/mlir/lib/Util/CMakeLists.txt index 2cfa6fa96..358b3554c 100644 --- a/mlir/lib/Util/CMakeLists.txt +++ b/mlir/lib/Util/CMakeLists.txt @@ -23,6 +23,7 @@ add_mlir_library(AIRUtil Dependency.cpp DependencyDot.cpp DirectedAdjacencyMap.cpp + MatmulCodegenConfig.cpp DEPENDS AIRDialect diff --git a/mlir/lib/Util/MatmulCodegenConfig.cpp b/mlir/lib/Util/MatmulCodegenConfig.cpp new file mode 100644 index 000000000..f1d210fdf --- /dev/null +++ b/mlir/lib/Util/MatmulCodegenConfig.cpp @@ -0,0 +1,101 @@ +//===- MatmulCodegenConfig.cpp ----------------------------------*- C++ -*-===// +// +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT +// +//===----------------------------------------------------------------------===// + +#include "air/Util/MatmulCodegenConfig.h" + +#include "mlir/Dialect/Linalg/IR/Linalg.h" + +using namespace mlir; + +namespace xilinx { +namespace air { + +std::optional findMatmulCodegenConfig(func::FuncOp funcOp) { + StringRef name = getMatmulCodegenConfigAttrName(); + std::optional found; + funcOp.walk([&](Operation *op) { + if (auto attr = op->getDiscardableAttr(name)) { + if (auto dict = dyn_cast(attr)) { + found = dict; + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }); + return found; +} + +SmallVector getI64Array(DictionaryAttr cfg, StringRef key) { + SmallVector out; + if (!cfg) + return out; + auto entry = cfg.get(key); + auto arr = dyn_cast_if_present(entry); + if (!arr) + return out; + for (Attribute a : arr) { + if (auto i = dyn_cast(a)) + out.push_back(i.getInt()); + } + return out; +} + +int64_t getI64(DictionaryAttr cfg, StringRef key, int64_t defaultVal) { + if (!cfg) + return defaultVal; + auto entry = cfg.get(key); + if (auto i = dyn_cast_if_present(entry)) + return i.getInt(); + return defaultVal; +} + +bool getBool(DictionaryAttr cfg, StringRef key, bool defaultVal) { + if (!cfg) + return defaultVal; + auto entry = cfg.get(key); + if (auto b = dyn_cast_if_present(entry)) + return b.getValue(); + return defaultVal; +} + +bool writeMatmulCodegenConfig(func::FuncOp funcOp, DictionaryAttr dict, + StringRef markerName) { + StringRef name = getMatmulCodegenConfigAttrName(); + Operation *target = nullptr; + if (!markerName.empty()) { + funcOp.walk([&](Operation *op) { + if (op->hasAttr(markerName)) { + target = op; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + } + if (!target) { + funcOp.walk([&](linalg::MatmulOp op) { + target = op.getOperation(); + return WalkResult::interrupt(); + }); + } + if (!target) + return false; + target->setDiscardableAttr(name, dict); + return true; +} + +DictionaryAttr buildMatmulCodegenConfig(MLIRContext *ctx, + ArrayRef entries) { + SmallVector filtered; + filtered.reserve(entries.size()); + for (const NamedAttribute &e : entries) + if (e.getValue()) + filtered.push_back(e); + return DictionaryAttr::get(ctx, filtered); +} + +} // namespace air +} // namespace xilinx diff --git a/mlir/lib/Util/Util.cpp b/mlir/lib/Util/Util.cpp index 8981d7e7b..cfcdf1249 100644 --- a/mlir/lib/Util/Util.cpp +++ b/mlir/lib/Util/Util.cpp @@ -2350,6 +2350,18 @@ Operation *air::cloneOpAndOperands(RewriterBase &rewriter, IRMapping &remap, return new_op; } +Operation *air::findOpWithAttr(Operation *root, StringRef attrName) { + Operation *found = nullptr; + root->walk([&](Operation *op) { + if (op->hasAttr(attrName)) { + found = op; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return found; +} + bool air::opOrAncestorIsDominantOver(Operation *a, Operation *b) { Region *commonRegion = air::findCommonRegionContainingAllAncestors( SmallVector{a, b}, nullptr); diff --git a/mlir/test/Transform/AIRMatmulPackAndTranspose/pack_basic.mlir b/mlir/test/Transform/AIRMatmulPackAndTranspose/pack_basic.mlir new file mode 100644 index 000000000..f337da469 --- /dev/null +++ b/mlir/test/Transform/AIRMatmulPackAndTranspose/pack_basic.mlir @@ -0,0 +1,47 @@ +//===- pack_basic.mlir ------------------------------------------*- MLIR -*-===// +// +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT +// +//===----------------------------------------------------------------------===// + +// RUN: air-opt %s -air-matmul-codegen='l2-pack-sizes=8,8,8 \ +// RUN: bufferize-last-pack-output=false' \ +// RUN: | FileCheck %s --check-prefix=NOPERM +// RUN: air-opt %s -air-matmul-codegen='l2-pack-sizes=8,8,8 \ +// RUN: l2-lhs-outer-perm=1,0 l2-rhs-outer-perm=1,0 l2-rhs-inner-perm=1,0 \ +// RUN: l2-acc-outer-perm=1,0 \ +// RUN: bufferize-last-pack-output=false' \ +// RUN: | FileCheck %s --check-prefix=ALLPERM + +// The accumulator pack of a zero-filled empty tensor is folded by the +// orchestrator's post-pack canonicalize into a single tensor.empty + +// linalg.fill in the packed shape; only the LHS and RHS packs survive. + +// NOPERM-LABEL: func.func @matmul_pack_basic +// NOPERM: linalg.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 8] +// NOPERM: linalg.pack %{{.*}} inner_dims_pos = [1, 0] inner_tiles = [8, 8] +// NOPERM: linalg.fill {{.*}} -> tensor<32x16x8x8xf32> +// NOPERM: linalg.generic +// NOPERM-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"] +// NOPERM-SAME: packed_matmul +// NOPERM: linalg.unpack + +// Test 54-style transposes: outer_perm=[1,0] on LHS, RHS, ACC + inner_perm=[1,0] on RHS. +// LHS (M,K) → outer-transposed to (K,M). +// RHS originally inner_dims_pos=[1,0]; outer_perm + inner_perm both [1,0] → inner_dims_pos=[0,1]. +// ALLPERM-LABEL: func.func @matmul_pack_basic +// ALLPERM: linalg.pack %{{.*}} outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 8] +// ALLPERM: linalg.pack %{{.*}} outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 8] +// ALLPERM: linalg.fill +// ALLPERM: linalg.generic +// ALLPERM-SAME: packed_matmul +// ALLPERM: linalg.unpack %{{.*}} outer_dims_perm = [1, 0] + +func.func @matmul_pack_basic(%a: tensor<256x784xf32>, %b: tensor<784x128xf32>) -> tensor<256x128xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<256x128xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x128xf32>) -> tensor<256x128xf32> + %2 = linalg.matmul ins(%a, %b : tensor<256x784xf32>, tensor<784x128xf32>) outs(%1 : tensor<256x128xf32>) -> tensor<256x128xf32> + return %2 : tensor<256x128xf32> +} diff --git a/mlir/test/Transform/AIRMatmulTileL3ToL2Copies/tile_copies_basic.mlir b/mlir/test/Transform/AIRMatmulTileL3ToL2Copies/tile_copies_basic.mlir new file mode 100644 index 000000000..361bd441a --- /dev/null +++ b/mlir/test/Transform/AIRMatmulTileL3ToL2Copies/tile_copies_basic.mlir @@ -0,0 +1,51 @@ +//===- tile_copies_basic.mlir -----------------------------------*- MLIR -*-===// +// +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT +// +//===----------------------------------------------------------------------===// + +// Triton-XDNA-style input: matmul preceded by L3->L2 memref.copy stagings. +// Verifies (1) memref.copy → linalg.copy conversion, (2) per-operand K-tiling, +// (3) loop annotations. + +// RUN: air-opt %s '-air-matmul-codegen=bufferize-output-l2=true tile-l3-to-l2-copies=true k-l2-tile=16' | FileCheck %s + +// CHECK-LABEL: func.func @matmul_with_l3_l2_copies +// LHS copy (64x784) is tiled by [0, 16] → outer scf.for over K, copy of 64x16 tiles. +// CHECK: memref.alloc() : memref<64x784xf32> +// CHECK: scf.for +// CHECK: memref.subview {{.*}} [64, 16] [1, 1] +// CHECK: memref.subview {{.*}} [64, 16] [1, 1] +// CHECK: linalg.copy ins(%{{.*}} : memref<64x16xf32{{.*}}>) outs(%{{.*}} : memref<64x16xf32{{.*}}>) +// CHECK: } {copy_a_loop} +// RHS copy (784x32) is tiled by [16, 0] → outer scf.for over K, copy of 16x32 tiles. +// CHECK: memref.alloc() : memref<784x32xf32> +// CHECK: scf.for +// CHECK: memref.subview {{.*}} [16, 32] [1, 1] +// CHECK: memref.subview {{.*}} [16, 32] [1, 1] +// CHECK: linalg.copy ins(%{{.*}} : memref<16x32xf32{{.*}}>) outs(%{{.*}} : memref<16x32xf32{{.*}}>) +// CHECK: } {copy_b_loop} +// CHECK: linalg.matmul + +func.func @matmul_with_l3_l2_copies(%argA: memref<*xf32>, %argB: memref<*xf32>, %argC: memref<*xf32>) { + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %reinterpret_a = memref.reinterpret_cast %argA to offset: [%c0], sizes: [64, 784], strides: [784, 1] : memref<*xf32> to memref<64x784xf32, strided<[784, 1], offset: ?>> + %alloc_a = memref.alloc() : memref<64x784xf32> + memref.copy %reinterpret_a, %alloc_a : memref<64x784xf32, strided<[784, 1], offset: ?>> to memref<64x784xf32> + %ta = bufferization.to_tensor %alloc_a restrict writable : memref<64x784xf32> to tensor<64x784xf32> + + %reinterpret_b = memref.reinterpret_cast %argB to offset: [%c0], sizes: [784, 32], strides: [32, 1] : memref<*xf32> to memref<784x32xf32, strided<[32, 1], offset: ?>> + %alloc_b = memref.alloc() : memref<784x32xf32> + memref.copy %reinterpret_b, %alloc_b : memref<784x32xf32, strided<[32, 1], offset: ?>> to memref<784x32xf32> + %tb = bufferization.to_tensor %alloc_b restrict writable : memref<784x32xf32> to tensor<784x32xf32> + + %tc_init = tensor.empty() : tensor<64x32xf32> + %tc_fill = linalg.fill ins(%cst : f32) outs(%tc_init : tensor<64x32xf32>) -> tensor<64x32xf32> + %tc = linalg.matmul ins(%ta, %tb : tensor<64x784xf32>, tensor<784x32xf32>) outs(%tc_fill : tensor<64x32xf32>) -> tensor<64x32xf32> + + %reinterpret_c = memref.reinterpret_cast %argC to offset: [%c0], sizes: [64, 32], strides: [32, 1] : memref<*xf32> to memref<64x32xf32, strided<[32, 1], offset: ?>> + bufferization.materialize_in_destination %tc in writable %reinterpret_c : (tensor<64x32xf32>, memref<64x32xf32, strided<[32, 1], offset: ?>>) -> () + return +} diff --git a/programming_examples/matrix_multiplication/bf16/run.py b/programming_examples/matrix_multiplication/bf16/run.py index 54159d9df..b46d497c3 100644 --- a/programming_examples/matrix_multiplication/bf16/run.py +++ b/programming_examples/matrix_multiplication/bf16/run.py @@ -7,10 +7,11 @@ from ml_dtypes import bfloat16 from air.ir import * +import air.passmanager from air.dialects.affine import apply as affine_apply from air.dialects.linalg import fill from air.dialects.air import * -from air.dialects.arith import ConstantOp +from air.dialects.arith import ConstantOp, MulIOp from air.dialects.memref import AllocOp, DeallocOp, load, store, subview from air.dialects.func import FuncOp from air.dialects.scf import for_, yield_ @@ -55,10 +56,13 @@ def build_module( arch="aie2", direct_codegen=False, ): - assert m % tile_m == 0 + # M, N must already be padded up to (tile_m * herd_m) / (tile_n * herd_n) + # by the caller (see padded-shape computation in __main__). K must be a + # full multiple of the L2 tile. + assert m % (tile_m * herd_m) == 0 + assert n % (tile_n * herd_n) == 0 assert k % tile_k_l2 == 0 assert tile_k_l2 % tile_k_l1 == 0 - assert n % tile_n == 0 a_size = [m, k] b_size = [k, n] c_size = [m, n] @@ -202,29 +206,17 @@ def segment_body( # semantics. l1_c_data = AllocOp(l1MemrefTyCHerd, [], []) - # Affine map for launch iv - launch_ix_map = AffineMap.get( - 0, - 1, - [ - AffineExpr.get_mul( - AffineSymbolExpr.get(0), - AffineConstantExpr.get(tile_m * herd_m), - ) - ], - ) - launch_iy_map = AffineMap.get( - 0, - 1, - [ - AffineExpr.get_mul( - AffineSymbolExpr.get(0), - AffineConstantExpr.get(tile_n * herd_n), - ) - ], - ) - launch_offset_x = affine_apply(launch_ix_map, [launch_ivx_s]) - launch_offset_y = affine_apply(launch_iy_map, [launch_ivy_s]) + # arith.muli of the launch block ID is the form + # air-split-launch-for-padding looks for when partitioning + # the launch into interior + tail tiles based on + # air.actual_sizes (see inferTileSize in + # mlir/lib/Transform/AIRSplitLaunchForPadding.cpp). Using + # affine.apply here would prevent that pass from inferring + # the launch tile size and would break hardware padding. + launch_tile_m_const = ConstantOp.create_index(tile_m * herd_m).result + launch_tile_n_const = ConstantOp.create_index(tile_n * herd_n).result + launch_offset_x = MulIOp(launch_ivx_s, launch_tile_m_const).result + launch_offset_y = MulIOp(launch_ivy_s, launch_tile_n_const).result @herd( name="herd_0", @@ -566,10 +558,20 @@ def herd_body( print("Peano is needed for direct code generation mode.", file=sys.stderr) sys.exit(1) + # Hardware padding: round M, N up to a multiple of the launch tile + # (tile_m * herd_m, tile_n * herd_n). The IR is built for the padded + # shape; aircc's air-split-launch-for-padding partitions the launch into + # interior + tail tiles when air.actual_sizes is set on air.launch. + launch_tile_m = args.tile_m * args.herd_m + launch_tile_n = args.tile_n * args.herd_n + m_padded = math.ceil(args.m / launch_tile_m) * launch_tile_m + n_padded = math.ceil(args.n / launch_tile_n) * launch_tile_n + needs_padding = (args.m != m_padded) or (args.n != n_padded) + mlir_module = build_module( - args.m, + m_padded, args.k, - args.n, + n_padded, args.tile_m, args.tile_k_l2, args.tile_k_l1, @@ -582,136 +584,51 @@ def herd_body( args.direct_codegen, ) - # Vectorization - only run if direct codegen mode is enabled + # Attach air.actual_sizes to the air.launch op iff the user-requested + # shape needs padding. Aircc's air-split-launch-for-padding pass reads + # this and is a no-op when the attribute is absent (so aligned shapes + # produce byte-identical IR to before this change). + if needs_padding: + with mlir_module.context, Location.unknown(): + actual_sizes_attr = Attribute.parse(f"array") + found = [None] + + def _visit(op): + if op.operation.name == "air.launch": + op.operation.attributes["air.actual_sizes"] = actual_sizes_attr + found[0] = op + return WalkResult.INTERRUPT + return WalkResult.ADVANCE + + mlir_module.operation.walk(_visit) + assert found[0] is not None, "no air.launch op produced by build_module" + + # Direct-codegen flow: only the vectorize stages of the C++ orchestrator + # (tile-for-vectorize + vec-prep). All earlier phases are skipped. if args.direct_codegen: - transform_ir_string = ( - """ - module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - - %func0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func0 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - %func_fold_1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_folded_1 = transform.air.fold_unit_extent_dims %func_fold_1 : (!transform.any_op) -> !transform.any_op - - - %matmul = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - - %inner_most_matmul, %vec_loops:3 = - transform.structured.tile_using_for %matmul tile_sizes [2, 2, 1, 0, 0, 0] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - %inner_most_matmul_to_unroll, %vec_loops_to_unroll:2 = - transform.structured.tile_using_for %inner_most_matmul tile_sizes [1, 1, 0, 0, 0, 0] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - transform.loop.unroll %vec_loops_to_unroll#1 {factor = 2} : !transform.any_op - transform.loop.unroll %vec_loops_to_unroll#0 {factor = 2} : !transform.any_op - - %linalg_fills = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %inner_most_fills, %vec_fill_loops:2 = - transform.structured.tile_using_for %linalg_fills tile_sizes [0, 0, 1, 1] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - - %herds = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %vectorized_herds = transform.air.herd_vectorize %herds : (!transform.any_op) -> !transform.any_op - - %herd1, %herd2, %herd3 = transform.split_handle %vectorized_herds : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - %scf_fors = transform.structured.match ops{["scf.for"]} in %herd2 : (!transform.any_op) -> !transform.any_op - - %func1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func1 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - transform.apply_patterns.memref.fold_memref_alias_ops - } : !transform.any_op - %func_fold_2 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_folded_2 = transform.air.fold_unit_extent_dims %func_fold_2 : (!transform.any_op) -> !transform.any_op - - // Eliminate redundant vector.transfer_read operations - %func1_rematch = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func1_optimized = transform.air.eliminate_redundant_vector_transfers %func1_rematch : (!transform.any_op) -> !transform.any_op - - // Hoist loop-invariant vector transfers out of innermost loop - %herds_1 = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %vectorized_herds_1 = transform.air.herd_vectorize %herds_1 : (!transform.any_op) -> !transform.any_op - %herd1_1, %herd2_1, %herd3_1 = transform.split_handle %vectorized_herds_1 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - - %scf_fors_1 = transform.structured.match ops{["scf.for"]} in %herd2_1 : (!transform.any_op) -> !transform.any_op - %innermost_for, %outer_fors = transform.split_handle %scf_fors_1 {overflow_result = 1} : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - - %vector_contracts = transform.structured.match ops{["vector.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %result11 = transform.air.vector_type_cast %vector_contracts {target_element_type = f32, input_indices = [2], output_indices = [0]} : (!transform.any_op) -> !transform.any_op - - // Hoist all accumulator transfer pairs from the innermost loop - %innermost_for_updated_3 = transform.air.hoist_loop_invariant_transfers %herd2_1, %innermost_for : (!transform.any_op, !transform.any_op) -> !transform.any_op - - %innermost_for_updated_4 = transform.air.flatten_for_iter_args %innermost_for_updated_3 : (!transform.any_op) -> !transform.any_op - %innermost_for_updated_5 = transform.air.hoist_vector_transfer_pointers %innermost_for_updated_4 : (!transform.any_op) -> !transform.any_op - - %fors_to_hoist_ptrs = transform.structured.match ops{["scf.for"]} in %herd2_1 : (!transform.any_op) -> !transform.any_op - %innermost_for1, %outer_fors1 = transform.split_handle %fors_to_hoist_ptrs {overflow_result = 1}: (!transform.any_op) -> (!transform.any_op, !transform.any_op) - - """ - + ( - """ - // Hoist the 4 extf/truncf pairs from the innermost loop - // (only applicable when output is bf16, producing paired extf/truncf ops) - %all_extf_loop = transform.structured.match ops{["arith.extf"]} in %innermost_for1 : (!transform.any_op) -> !transform.any_op - %all_truncf_loop = transform.structured.match ops{["arith.truncf"]} in %innermost_for1 : (!transform.any_op) -> !transform.any_op - - // Split to get individual operations (4 extf total) - %extf_bf16_1, %extf_bf16_2, %extf_bf16_3, %extf_bf16_4 = transform.split_handle %all_extf_loop : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - - // The 4 truncf ops correspond to the 4 vector.contract results - %truncf_1, %truncf_2, %truncf_3, %truncf_4 = transform.split_handle %all_truncf_loop : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - - // Hoist first pair - %for1_1_hoisted_1 = transform.air.hoist_cast_pair %extf_bf16_1, %truncf_1, %innermost_for1 : (!transform.any_op, !transform.any_op, !transform.any_op) -> !transform.any_op - - // Re-match and hoist second pair - %all_extf_loop_2 = transform.structured.match ops{["arith.extf"]} in %for1_1_hoisted_1 : (!transform.any_op) -> !transform.any_op - %all_truncf_loop_2 = transform.structured.match ops{["arith.truncf"]} in %for1_1_hoisted_1 : (!transform.any_op) -> !transform.any_op - %extf_bf16_2_new, %e2_5, %e2_6 = transform.split_handle %all_extf_loop_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - %truncf_2_1, %truncf_2_2, %truncf_2_3 = transform.split_handle %all_truncf_loop_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - %for1_1_hoisted_2 = transform.air.hoist_cast_pair %extf_bf16_2_new, %truncf_2_1, %for1_1_hoisted_1 : (!transform.any_op, !transform.any_op, !transform.any_op) -> !transform.any_op - - // Re-match and hoist third pair - %all_extf_loop_3 = transform.structured.match ops{["arith.extf"]} in %for1_1_hoisted_2 : (!transform.any_op) -> !transform.any_op - %all_truncf_loop_3 = transform.structured.match ops{["arith.truncf"]} in %for1_1_hoisted_2 : (!transform.any_op) -> !transform.any_op - %extf_bf16_3_new, %e3_7 = transform.split_handle %all_extf_loop_3 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %truncf_3_1, %truncf_3_2 = transform.split_handle %all_truncf_loop_3 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %for1_1_hoisted_3 = transform.air.hoist_cast_pair %extf_bf16_3_new, %truncf_3_1, %for1_1_hoisted_2 : (!transform.any_op, !transform.any_op, !transform.any_op) -> !transform.any_op - - // Re-match and hoist fourth pair - %all_extf_loop_4 = transform.structured.match ops{["arith.extf"]} in %for1_1_hoisted_3 : (!transform.any_op) -> !transform.any_op - %all_truncf_loop_4 = transform.structured.match ops{["arith.truncf"]} in %for1_1_hoisted_3 : (!transform.any_op) -> !transform.any_op - %for1_1_hoisted_final = transform.air.hoist_cast_pair %all_extf_loop_4, %all_truncf_loop_4, %for1_1_hoisted_3 : (!transform.any_op, !transform.any_op, !transform.any_op) -> !transform.any_op - """ - if OUTPUT_DATATYPE == bfloat16 - else "" - ) - + """ - - %func2 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func2 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - transform.apply_patterns.memref.fold_memref_alias_ops - } : !transform.any_op - %func_fold_3 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_folded_3 = transform.air.fold_unit_extent_dims %func_fold_3 : (!transform.any_op) -> !transform.any_op - transform.yield - } - } - """ - ) - transform_ir = Module.parse(transform_ir_string, context=mlir_module.context) - run_transform(transform_ir, mlir_module) + hoist_pairs = "true" if OUTPUT_DATATYPE == bfloat16 else "false" + steps = [ + "func.func(canonicalize,cse)", + "air-matmul-codegen{" + "matmul-vec-tile=2,2,1,0,0,0 " + "matmul-unroll-vec-tile=1,1,0,0,0,0 " + "matmul-unroll-factor=2 fill-vec-tile=0,0,1,1 " + "}", + "func.func(air-herd-vectorize)", + "func.func(canonicalize,cse,fold-memref-alias-ops)", + # Vec-prep composite: eliminate-redundant + cast(f32) + hoist-loop + + # flatten + hoist-pointers + (bf16-out: hoist-cast-pairs). + "air-matmul-codegen{" + "vec-prep-cast1-target-element-type=f32 " + "vec-prep-cast1-input-indices=2 " + "vec-prep-cast1-output-indices=0 " + f"vec-prep-hoist-cast-pairs={hoist_pairs}" + "}", + "func.func(canonicalize,cse,fold-memref-alias-ops)", + ] + pipeline = "builtin.module(" + ",".join(steps) + ")" + pm = air.passmanager.PassManager.parse(pipeline, context=mlir_module.context) + pm.run(mlir_module.operation) if args.print_module_only: print(mlir_module) exit(0) @@ -719,13 +636,25 @@ def herd_body( # Variance-normalized inputs following PyTorch's # random_matrix_with_scaled_reduction_dim: randn / sqrt(K). # This keeps output variance ~1 regardless of K, so relative - # tolerance behaves consistently across matrix sizes. + # tolerance behaves consistently across matrix sizes. Buffers are + # allocated at the padded shape (m_padded, k) / (k, n_padded) so the + # kernel can read/write whole launch tiles. When padding is needed, + # the tail rows/columns are zero so the resulting output tail is also + # zero and contributes no error to interior tiles. scale = 1.0 / math.sqrt(args.k) - input_a = (np.random.randn(args.m, args.k) * scale).astype(INPUT_DATATYPE) - input_b = (np.random.randn(args.k, args.n) * scale).astype(INPUT_DATATYPE) + input_a = np.zeros((m_padded, args.k), dtype=INPUT_DATATYPE) + input_a[: args.m, :] = (np.random.randn(args.m, args.k) * scale).astype( + INPUT_DATATYPE + ) + input_b = np.zeros((args.k, n_padded), dtype=INPUT_DATATYPE) + input_b[:, : args.n] = (np.random.randn(args.k, args.n) * scale).astype( + INPUT_DATATYPE + ) if args.compile_mode == "compile-and-run": # Stochastically sample results and pass to XRTRunner for verification. + # Indices are clamped to the actual M, N (no point checking the + # zero-padded tail). num_samples = 100 sampled_indices = np.vstack( [ @@ -749,8 +678,10 @@ def herd_body( dtype=OUTPUT_DATATYPE, ) + # Output buffer comes back at the padded shape; we only validate + # entries in the [0:M, 0:N] interior. sampled_data = { - "shape": (args.m, args.n), + "shape": (m_padded, n_padded), "indices": sampled_indices, "values": sampled_values, } diff --git a/programming_examples/matrix_multiplication/i16/run.py b/programming_examples/matrix_multiplication/i16/run.py index bda88ed70..6f3cc1ead 100644 --- a/programming_examples/matrix_multiplication/i16/run.py +++ b/programming_examples/matrix_multiplication/i16/run.py @@ -1,6 +1,7 @@ # Copyright (C) 2025, Advanced Micro Devices, Inc. # SPDX-License-Identifier: MIT import argparse +import math import os import sys @@ -8,7 +9,7 @@ from air.dialects.affine import apply as affine_apply from air.dialects.linalg import fill from air.dialects.air import * -from air.dialects.arith import ConstantOp +from air.dialects.arith import ConstantOp, MulIOp from air.dialects.memref import AllocOp, DeallocOp, load, store, subview from air.dialects.func import FuncOp from air.dialects.scf import for_, yield_ @@ -52,10 +53,12 @@ def build_module( np_dtype_out, arch="aie2", ): - assert m % tile_m == 0 + # M, N must already be padded up to (tile_m * herd_m) / (tile_n * herd_n) + # by the caller (see padded-shape computation in __main__). + assert m % (tile_m * herd_m) == 0 + assert n % (tile_n * herd_n) == 0 assert k % tile_k_l2 == 0 assert tile_k_l2 % tile_k_l1 == 0 - assert n % tile_n == 0 a_size = [m, k] b_size = [k, n] c_size = [m, n] @@ -197,29 +200,15 @@ def segment_body( # semantics. l1_c_data = AllocOp(l1MemrefTyCHerd, [], []) - # Affine map for launch iv - launch_ix_map = AffineMap.get( - 0, - 1, - [ - AffineExpr.get_mul( - AffineSymbolExpr.get(0), - AffineConstantExpr.get(tile_m * herd_m), - ) - ], - ) - launch_iy_map = AffineMap.get( - 0, - 1, - [ - AffineExpr.get_mul( - AffineSymbolExpr.get(0), - AffineConstantExpr.get(tile_n * herd_n), - ) - ], - ) - launch_offset_x = affine_apply(launch_ix_map, [launch_ivx_s]) - launch_offset_y = affine_apply(launch_iy_map, [launch_ivy_s]) + # arith.muli on launch block IDs is the form + # air-split-launch-for-padding looks for when partitioning + # the launch into interior + tail tiles based on + # air.actual_sizes (see inferTileSize in + # mlir/lib/Transform/AIRSplitLaunchForPadding.cpp). + launch_tile_m_const = ConstantOp.create_index(tile_m * herd_m).result + launch_tile_n_const = ConstantOp.create_index(tile_n * herd_n).result + launch_offset_x = MulIOp(launch_ivx_s, launch_tile_m_const).result + launch_offset_y = MulIOp(launch_ivy_s, launch_tile_n_const).result @herd( name="herd_0", @@ -544,10 +533,19 @@ def herd_body( print("Peano is needed for direct code generation mode.", file=sys.stderr) sys.exit(1) + # Hardware padding: round M, N up to a multiple of the launch tile. + # Aircc's air-split-launch-for-padding partitions the launch into + # interior + tail tiles when air.actual_sizes is set on air.launch. + launch_tile_m = args.tile_m * args.herd_m + launch_tile_n = args.tile_n * args.herd_n + m_padded = math.ceil(args.m / launch_tile_m) * launch_tile_m + n_padded = math.ceil(args.n / launch_tile_n) * launch_tile_n + needs_padding = (args.m != m_padded) or (args.n != n_padded) + mlir_module = build_module( - args.m, + m_padded, args.k, - args.n, + n_padded, args.tile_m, args.tile_k_l2, args.tile_k_l1, @@ -559,6 +557,21 @@ def herd_body( args.arch, ) + if needs_padding: + with mlir_module.context, Location.unknown(): + actual_sizes_attr = Attribute.parse(f"array") + found = [None] + + def _visit(op): + if op.operation.name == "air.launch": + op.operation.attributes["air.actual_sizes"] = actual_sizes_attr + found[0] = op + return WalkResult.INTERRUPT + return WalkResult.ADVANCE + + mlir_module.operation.walk(_visit) + assert found[0] is not None, "no air.launch op produced by build_module" + # Vectorization - only run if direct codegen mode is enabled if args.direct_codegen: # Architecture-specific accumulator type for vector intrinsics @@ -685,10 +698,16 @@ def herd_body( print(mlir_module) exit(0) - input_a = np.arange(0, args.m * args.k, dtype=np.int64).reshape(args.m, args.k) % 7 - input_a = input_a.astype(INPUT_DATATYPE) - input_b = np.arange(0, args.k * args.n, dtype=np.int64).reshape(args.k, args.n) % 7 - input_b = input_b.astype(INPUT_DATATYPE) + # Buffers allocated at the padded shape; tail rows/cols stay zero so + # the matmul output's tail is also zero (and is not validated). + input_a = np.zeros((m_padded, args.k), dtype=INPUT_DATATYPE) + input_a[: args.m, :] = ( + np.arange(0, args.m * args.k, dtype=np.int64).reshape(args.m, args.k) % 7 + ).astype(INPUT_DATATYPE) + input_b = np.zeros((args.k, n_padded), dtype=INPUT_DATATYPE) + input_b[:, : args.n] = ( + np.arange(0, args.k * args.n, dtype=np.int64).reshape(args.k, args.n) % 7 + ).astype(INPUT_DATATYPE) if args.compile_mode == "compile-and-run": @@ -716,9 +735,10 @@ def herd_body( dtype=OUTPUT_DATATYPE, ) - # Store as a dictionary + # Output comes back at the padded shape; only validate the + # [0:M, 0:N] interior. sampled_data = { - "shape": (args.m, args.n), + "shape": (m_padded, n_padded), "indices": sampled_indices, "values": sampled_values, } diff --git a/programming_examples/matrix_multiplication/i8/run.py b/programming_examples/matrix_multiplication/i8/run.py index bac0278ec..c851f91d0 100644 --- a/programming_examples/matrix_multiplication/i8/run.py +++ b/programming_examples/matrix_multiplication/i8/run.py @@ -1,14 +1,16 @@ # Copyright (C) 2025, Advanced Micro Devices, Inc. # SPDX-License-Identifier: MIT import argparse +import math import os import sys from air.ir import * +import air.passmanager from air.dialects.affine import apply as affine_apply from air.dialects.linalg import fill from air.dialects.air import * -from air.dialects.arith import ConstantOp +from air.dialects.arith import ConstantOp, MulIOp from air.dialects.memref import AllocOp, DeallocOp, load, store, subview from air.dialects.func import FuncOp from air.dialects.scf import for_, yield_ @@ -52,10 +54,12 @@ def build_module( np_dtype_out, arch="aie2", ): - assert m % tile_m == 0 + # M, N must already be padded up to (tile_m * herd_m) / (tile_n * herd_n) + # by the caller (see padded-shape computation in __main__). + assert m % (tile_m * herd_m) == 0 + assert n % (tile_n * herd_n) == 0 assert k % tile_k_l2 == 0 assert tile_k_l2 % tile_k_l1 == 0 - assert n % tile_n == 0 a_size = [m, k] b_size = [k, n] c_size = [m, n] @@ -197,29 +201,15 @@ def segment_body( # semantics. l1_c_data = AllocOp(l1MemrefTyCHerd, [], []) - # Affine map for launch iv - launch_ix_map = AffineMap.get( - 0, - 1, - [ - AffineExpr.get_mul( - AffineSymbolExpr.get(0), - AffineConstantExpr.get(tile_m * herd_m), - ) - ], - ) - launch_iy_map = AffineMap.get( - 0, - 1, - [ - AffineExpr.get_mul( - AffineSymbolExpr.get(0), - AffineConstantExpr.get(tile_n * herd_n), - ) - ], - ) - launch_offset_x = affine_apply(launch_ix_map, [launch_ivx_s]) - launch_offset_y = affine_apply(launch_iy_map, [launch_ivy_s]) + # arith.muli on launch block IDs is the form + # air-split-launch-for-padding looks for when partitioning + # the launch into interior + tail tiles based on + # air.actual_sizes (see inferTileSize in + # mlir/lib/Transform/AIRSplitLaunchForPadding.cpp). + launch_tile_m_const = ConstantOp.create_index(tile_m * herd_m).result + launch_tile_n_const = ConstantOp.create_index(tile_n * herd_n).result + launch_offset_x = MulIOp(launch_ivx_s, launch_tile_m_const).result + launch_offset_y = MulIOp(launch_ivy_s, launch_tile_n_const).result @herd( name="herd_0", @@ -544,10 +534,19 @@ def herd_body( print("Peano is needed for direct code generation mode.", file=sys.stderr) sys.exit(1) + # Hardware padding: round M, N up to a multiple of the launch tile. + # Aircc's air-split-launch-for-padding partitions the launch into + # interior + tail tiles when air.actual_sizes is set on air.launch. + launch_tile_m = args.tile_m * args.herd_m + launch_tile_n = args.tile_n * args.herd_n + m_padded = math.ceil(args.m / launch_tile_m) * launch_tile_m + n_padded = math.ceil(args.n / launch_tile_n) * launch_tile_n + needs_padding = (args.m != m_padded) or (args.n != n_padded) + mlir_module = build_module( - args.m, + m_padded, args.k, - args.n, + n_padded, args.tile_m, args.tile_k_l2, args.tile_k_l1, @@ -559,136 +558,63 @@ def herd_body( args.arch, ) - # Vectorization - only run if direct codegen mode is enabled + if needs_padding: + with mlir_module.context, Location.unknown(): + actual_sizes_attr = Attribute.parse(f"array") + found = [None] + + def _visit(op): + if op.operation.name == "air.launch": + op.operation.attributes["air.actual_sizes"] = actual_sizes_attr + found[0] = op + return WalkResult.INTERRUPT + return WalkResult.ADVANCE + + mlir_module.operation.walk(_visit) + assert found[0] is not None, "no air.launch op produced by build_module" + + # Direct-codegen flow: only the vectorize stages of the C++ orchestrator + # (tile-for-vectorize + vec-prep). All earlier phases are skipped. if args.direct_codegen: - transform_ir_string = """ - module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - - %func0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func0 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - %func_fold_1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_folded_1 = transform.air.fold_unit_extent_dims %func_fold_1 : (!transform.any_op) -> !transform.any_op - - - %matmul = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %inner_most_matmul, %vec_loops:3 = - transform.structured.tile_using_for %matmul tile_sizes [2, 2, 1, 0, 0, 0] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - %inner_most_matmul_to_unroll, %vec_loops_to_unroll:2 = - transform.structured.tile_using_for %inner_most_matmul tile_sizes [1, 1, 0, 0, 0, 0] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - transform.loop.unroll %vec_loops_to_unroll#1 {factor = 2} : !transform.any_op - transform.loop.unroll %vec_loops_to_unroll#0 {factor = 2} : !transform.any_op - - %linalg_fills = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %inner_most_fills, %vec_fill_loops:2 = - transform.structured.tile_using_for %linalg_fills tile_sizes [0, 0, 1, 1] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - - - %herds = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %vectorized_herds = transform.air.herd_vectorize %herds : (!transform.any_op) -> !transform.any_op - - %herd1, %herd2, %herd3 = transform.split_handle %vectorized_herds : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - %scf_fors = transform.structured.match ops{["scf.for"]} in %herd2 : (!transform.any_op) -> !transform.any_op - - %func1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func1 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - transform.apply_patterns.memref.fold_memref_alias_ops - } : !transform.any_op - %func_fold_2 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_folded_2 = transform.air.fold_unit_extent_dims %func_fold_2 : (!transform.any_op) -> !transform.any_op - - // Eliminate redundant vector.transfer_read operations - %func1_rematch = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func1_optimized = transform.air.eliminate_redundant_vector_transfers %func1_rematch : (!transform.any_op) -> !transform.any_op - - // Hoist loop-invariant vector transfers out of innermost loop - %herds_1 = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %herd1_1, %herd2_1, %herd3_1 = transform.split_handle %herds_1 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - - %scf_fors_1 = transform.structured.match ops{["scf.for"]} in %herd2_1 : (!transform.any_op) -> !transform.any_op - %innermost_for, %outer_fors = transform.split_handle %scf_fors_1 {overflow_result = 1} : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - - %vector_contracts = transform.structured.match ops{["vector.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %result11 = transform.air.vector_type_cast %vector_contracts {target_element_type = i32, input_indices = [2], output_indices = [0]} : (!transform.any_op) -> !transform.any_op - - // Hoist all accumulator transfer pairs from the innermost loop - %innermost_for_updated_3 = transform.air.hoist_loop_invariant_transfers %herd2_1, %innermost_for : (!transform.any_op, !transform.any_op) -> !transform.any_op - - %innermost_for_updated_4 = transform.air.flatten_for_iter_args %innermost_for_updated_3 : (!transform.any_op) -> !transform.any_op - %innermost_for_updated_5 = transform.air.hoist_vector_transfer_pointers %innermost_for_updated_4 : (!transform.any_op) -> !transform.any_op - - %fors_to_hoist_ptrs = transform.structured.match ops{["scf.for"]} in %herd2_1 : (!transform.any_op) -> !transform.any_op - %innermost_for1, %outer_fors1 = transform.split_handle %fors_to_hoist_ptrs {overflow_result = 1}: (!transform.any_op) -> (!transform.any_op, !transform.any_op) - - // Hoist the 4 extsi/trunci pairs from the innermost loop - // Pattern: each iter has (2 i8→i16, 1 i16→i32) so total 12 extsi ops - // i16→i32 extsi ops are at indices 2, 5, 8, 11 (0-indexed) - %all_extsi_loop = transform.structured.match ops{["arith.extsi"]} in %innermost_for1 : (!transform.any_op) -> !transform.any_op - %all_trunci_loop = transform.structured.match ops{["arith.trunci"]} in %innermost_for1 : (!transform.any_op) -> !transform.any_op - - // Split to get individual operations (12 extsi total) - %e0, %e1, %extsi_i16_1, %e3, %e4, %extsi_i16_2, %e6, %e7, %extsi_i16_3, %e9, %e10, %extsi_i16_4 = transform.split_handle %all_extsi_loop {num_result_handles = 12} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - - // The 4 trunci ops correspond to the 4 vector.contract results - %trunci_1, %trunci_2, %trunci_3, %trunci_4 = transform.split_handle %all_trunci_loop {num_result_handles = 4} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - - // Hoist first pair (arg29 - index 2) - %for1_1_hoisted_1 = transform.air.hoist_cast_pair %extsi_i16_1, %trunci_1, %innermost_for1 : (!transform.any_op, !transform.any_op, !transform.any_op) -> !transform.any_op - - // Re-match and hoist second pair (arg30 - was index 5, now 4 after first hoist) - %all_extsi_loop_2 = transform.structured.match ops{["arith.extsi"]} in %for1_1_hoisted_1 : (!transform.any_op) -> !transform.any_op - %all_trunci_loop_2 = transform.structured.match ops{["arith.trunci"]} in %for1_1_hoisted_1 : (!transform.any_op) -> !transform.any_op - %e2_0, %e2_1, %e2_2, %e2_3, %extsi_i16_2_new, %e2_5, %e2_6, %e2_7, %e2_8, %e2_9, %e2_10 = transform.split_handle %all_extsi_loop_2 {num_result_handles = 11} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - %trunci_2_1, %trunci_2_2, %trunci_2_3 = transform.split_handle %all_trunci_loop_2 {num_result_handles = 3} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - %for1_1_hoisted_2 = transform.air.hoist_cast_pair %extsi_i16_2_new, %trunci_2_1, %for1_1_hoisted_1 : (!transform.any_op, !transform.any_op, !transform.any_op) -> !transform.any_op - - // Re-match and hoist third pair (arg31 - was index 8, now 6 after two hoists) - %all_extsi_loop_3 = transform.structured.match ops{["arith.extsi"]} in %for1_1_hoisted_2 : (!transform.any_op) -> !transform.any_op - %all_trunci_loop_3 = transform.structured.match ops{["arith.trunci"]} in %for1_1_hoisted_2 : (!transform.any_op) -> !transform.any_op - %e3_0, %e3_1, %e3_2, %e3_3, %e3_4, %e3_5, %extsi_i16_3_new, %e3_7, %e3_8, %e3_9 = transform.split_handle %all_extsi_loop_3 {num_result_handles = 10} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - %trunci_3_1, %trunci_3_2 = transform.split_handle %all_trunci_loop_3 {num_result_handles = 2} : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %for1_1_hoisted_3 = transform.air.hoist_cast_pair %extsi_i16_3_new, %trunci_3_1, %for1_1_hoisted_2 : (!transform.any_op, !transform.any_op, !transform.any_op) -> !transform.any_op - - // Re-match and hoist fourth pair (arg32 - was index 11, now 8 after three hoists) - %all_extsi_loop_4 = transform.structured.match ops{["arith.extsi"]} in %for1_1_hoisted_3 : (!transform.any_op) -> !transform.any_op - %all_trunci_loop_4 = transform.structured.match ops{["arith.trunci"]} in %for1_1_hoisted_3 : (!transform.any_op) -> !transform.any_op - // Now should have 8 i8→i16 extsi and 1 i16→i32 extsi remaining (9 total) - %e4_0, %e4_1, %e4_2, %e4_3, %e4_4, %e4_5, %e4_6, %e4_7, %extsi_i16_4_final = transform.split_handle %all_extsi_loop_4 {num_result_handles = 9} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - %for1_1_hoisted_final = transform.air.hoist_cast_pair %extsi_i16_4_final, %all_trunci_loop_4, %for1_1_hoisted_3 : (!transform.any_op, !transform.any_op, !transform.any_op) -> !transform.any_op - - %func2 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func2 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - transform.apply_patterns.memref.fold_memref_alias_ops - } : !transform.any_op - %func_fold_3 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_folded_3 = transform.air.fold_unit_extent_dims %func_fold_3 : (!transform.any_op) -> !transform.any_op - transform.yield - } - } - """ - transform_ir = Module.parse(transform_ir_string, context=mlir_module.context) - run_transform(transform_ir, mlir_module) + pipeline = ( + "builtin.module(" + + ",".join( + [ + "func.func(canonicalize,cse)", + "air-matmul-codegen{" + "matmul-vec-tile=2,2,1,0,0,0 " + "matmul-unroll-vec-tile=1,1,0,0,0,0 " + "matmul-unroll-factor=2 fill-vec-tile=0,0,1,1 " + "}", + "func.func(air-herd-vectorize)", + "func.func(canonicalize,cse,fold-memref-alias-ops)", + "air-matmul-codegen{" + "vec-prep-cast1-target-element-type=i32 " + "vec-prep-cast1-input-indices=2 " + "vec-prep-cast1-output-indices=0 " + "vec-prep-hoist-cast-pairs=true" + "}", + "func.func(canonicalize,cse,fold-memref-alias-ops)", + ] + ) + + ")" + ) + pm = air.passmanager.PassManager.parse(pipeline, context=mlir_module.context) + pm.run(mlir_module.operation) if args.print_module_only: print(mlir_module) exit(0) - input_a = np.arange(0, args.m * args.k, dtype=np.int64).reshape(args.m, args.k) % 7 - input_a = input_a.astype(INPUT_DATATYPE) - input_b = np.arange(0, args.k * args.n, dtype=np.int64).reshape(args.k, args.n) % 7 - input_b = input_b.astype(INPUT_DATATYPE) + # Buffers allocated at the padded shape; tail rows/cols stay zero so the + # matmul output's tail is also zero (and is not validated). + input_a = np.zeros((m_padded, args.k), dtype=INPUT_DATATYPE) + input_a[: args.m, :] = ( + np.arange(0, args.m * args.k, dtype=np.int64).reshape(args.m, args.k) % 7 + ).astype(INPUT_DATATYPE) + input_b = np.zeros((args.k, n_padded), dtype=INPUT_DATATYPE) + input_b[:, : args.n] = ( + np.arange(0, args.k * args.n, dtype=np.int64).reshape(args.k, args.n) % 7 + ).astype(INPUT_DATATYPE) if args.compile_mode == "compile-and-run": @@ -716,9 +642,10 @@ def herd_body( dtype=OUTPUT_DATATYPE, ) - # Store as a dictionary + # Output comes back at the padded shape; only validate the + # [0:M, 0:N] interior. sampled_data = { - "shape": (args.m, args.n), + "shape": (m_padded, n_padded), "indices": sampled_indices, "values": sampled_values, } diff --git a/test/xrt/37_matmul_transform_4x4_bf16/run.py b/test/xrt/37_matmul_transform_4x4_bf16/run.py index d950a6367..e2a496ae6 100644 --- a/test/xrt/37_matmul_transform_4x4_bf16/run.py +++ b/test/xrt/37_matmul_transform_4x4_bf16/run.py @@ -125,7 +125,10 @@ def forward(lhs, rhs): ## Tiling ################################################ -# Load the MLIR transform IR from an external file +# Drive matmul codegen via the transform script. transform_aie2p.mlir +# delegates to the C++ air-matmul-codegen orchestrator via +# transform.apply_registered_pass; transform_aie2.mlir is the legacy +# hand-rolled NPU1 path. with open(args.transform_script, "r") as f: transform_ir_string = f.read() transform_ir = Module.parse(transform_ir_string, context=context) diff --git a/test/xrt/37_matmul_transform_4x4_bf16/transform_aie2p.mlir b/test/xrt/37_matmul_transform_4x4_bf16/transform_aie2p.mlir index a40f6993d..47d9a8b21 100644 --- a/test/xrt/37_matmul_transform_4x4_bf16/transform_aie2p.mlir +++ b/test/xrt/37_matmul_transform_4x4_bf16/transform_aie2p.mlir @@ -1,203 +1,37 @@ -// Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. // SPDX-License-Identifier: MIT -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %fill = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - - // First level tile to forall. - %first_level_tiled_matmul, %outer_forall = - transform.structured.tile_using_forall %matmul tile_sizes [256, 256] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - - // Fuse fill operation into the forall loop. - %fused_fill, %1 = transform.structured.fuse_into_containing_op %fill into %outer_forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - - // First level pack the matmul. - %first_level_tiled_transposed_l2_packed_matmul = transform.structured.pack %first_level_tiled_matmul packed_sizes = [64, 64, 64] - : (!transform.any_op) -> (!transform.any_op) - - %lhs_transposed_l2_pack_op = transform.get_producer_of_operand %first_level_tiled_transposed_l2_packed_matmul[0] : (!transform.any_op) -> (!transform.any_op) - %first_level_tiled_l2_packed_matmul, %lhs_l2_pack, %lhs_unpack = - transform.structured.pack_transpose %lhs_transposed_l2_pack_op with_compute_op(%first_level_tiled_transposed_l2_packed_matmul) - outer_perm = [0, 1] inner_perm = [0, 1] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - %rhs_transposed_l2_pack_op = transform.get_producer_of_operand %first_level_tiled_l2_packed_matmul[1] : (!transform.any_op) -> (!transform.any_op) - %first_level_tiled_l2_packed_matmul_lhs_transposed, %rhs_l2_pack, %rhs_unpack = - transform.structured.pack_transpose %rhs_transposed_l2_pack_op with_compute_op(%first_level_tiled_l2_packed_matmul) - outer_perm = [1, 0] inner_perm = [1, 0] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - // Run canonicalization - %func1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func1 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func1 : !transform.any_op - - // Promote the fused fill to shared memory - %result_l2 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %result_l2_buffer, %result_t2_new = transform.structured.bufferize_to_allocation %result_l2 - {memory_space = 1, bufferize_destination_only, mempcy = "linalg.copy", emit_dealloc} : !transform.any_op - - // Second level pack the matmul. - %generic_op = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %l1_packed = transform.structured.pack %generic_op packed_sizes = [0, 0, 0, 8, 8, 8] - : (!transform.any_op) -> (!transform.any_op) - - // Transpose A matrix from [M K m k m0 k0] to [M K k m m0 k0] - %l1_packed_lhs = transform.get_producer_of_operand %l1_packed[0] - : (!transform.any_op) -> (!transform.any_op) - %lhs_l1_packed_matmul, %lhs_l1_pack_op, %lhs_l1_unpack_op = - transform.structured.pack_transpose %l1_packed_lhs with_compute_op(%l1_packed) - outer_perm = [0, 1, 3, 2] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - // Transpose B matrix from [K N k n n0 k0] to [K N n k k0 n0] - %l1_packed_rhs = transform.get_producer_of_operand %lhs_l1_packed_matmul[1] - : (!transform.any_op) -> (!transform.any_op) - %operands_l1_packed_matmul, %rhs_l1_pack_op, %rhs_l1_unpack_op = - transform.structured.pack_transpose %l1_packed_rhs with_compute_op(%lhs_l1_packed_matmul) - outer_perm = [0, 1, 3, 2] inner_perm = [1, 0] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - // Transpose C matrix from [M N m n m0 n0] to [M N n m m0 n0] - %l1_packed_output = transform.get_consumers_of_result %operands_l1_packed_matmul[0] - : (!transform.any_op) -> (!transform.any_op) - %l1_packed_matmul, %output_l1_pack_op, %output_l1_unpack_op = - transform.structured.pack_transpose %l1_packed_output with_compute_op(%operands_l1_packed_matmul) - outer_perm = [0, 1, 3, 2] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - // Promote the result to local memory - %output_l1_pack_op_source_buffer, %output_l1_pack_op_new = transform.structured.bufferize_to_allocation %output_l1_pack_op - {memory_space = 2, bufferize_destination_only, memcpy_op = "linalg.copy", emit_dealloc} : !transform.any_op - - // First level for loop. - %first_level_tiled_reduction_matmul, %outer_for_loop = - transform.structured.tile_using_for %l1_packed_matmul tile_sizes [0, 0, 1] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - - // Fuse the pack operations in the outer for loop. - %fused_lhs_l1_pack, %2 = transform.structured.fuse_into_containing_op %lhs_l1_pack_op into %outer_for_loop : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - %fused_rhs_l1_pack, %3 = transform.structured.fuse_into_containing_op %rhs_l1_pack_op into %outer_for_loop : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - %fused_lhs_l2_pack, %4 = transform.structured.fuse_into_containing_op %lhs_l2_pack into %outer_for_loop : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - %fused_rhs_l2_pack, %5 = transform.structured.fuse_into_containing_op %rhs_l2_pack into %outer_for_loop : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - - // Promote the lhs to shared memory - %lhs_l2_pack_buffer, %lhs_l2_pack_new = transform.structured.bufferize_to_allocation %fused_lhs_l2_pack - {memory_space = 1, bufferize_destination_only, memcpy_op = "linalg.copy", emit_dealloc} : !transform.any_op - - // Promote the rhs to shared memory - %rhs_l2_pack_buffer, %rhs_l2_pack_new = transform.structured.bufferize_to_allocation %fused_rhs_l2_pack - {memory_space = 1, bufferize_destination_only, memcpy_op = "linalg.copy", emit_dealloc} : !transform.any_op +// +// AIE2P (Strix) two-pack-level matmul codegen via the C++ +// air-matmul-codegen orchestrator. M=512 N=512 K=1024. +// Per-launch matmul: 256x256x1024. - // Run canonicalization - %func2 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func2 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func2 : !transform.any_op - - // Second level tile to forall with tile_sizes. - %second_level_tiled_matmul, %inner_forall = - transform.structured.tile_using_forall %first_level_tiled_reduction_matmul tile_sizes [1, 1, 0, 0, 0, 0] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - - // Fuse the pack operations in inner forall loop. - %fused_lhs_l1_pack2, %6 = transform.structured.fuse_into_containing_op %fused_lhs_l1_pack into %inner_forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - %fused_rhs_l1_pack2, %7 = transform.structured.fuse_into_containing_op %fused_rhs_l1_pack into %inner_forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - - // Second level for loop. - %generic_op1 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %second_level_tiled_reduction_matmul, %inner_for_loop = - transform.structured.tile_using_for %generic_op1 tile_sizes [0, 0, 0, 0, 0, 8] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - - // Fuse the pack operations in inner for loop. - %fused_lhs_l1_pack3, %8 = transform.structured.fuse_into_containing_op %fused_lhs_l1_pack2 into %inner_for_loop : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - %fused_rhs_l1_pack3, %9 = transform.structured.fuse_into_containing_op %fused_rhs_l1_pack2 into %inner_for_loop : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - - // Promote the LHS to local memory. - %lhs_l1_pack_buffer, %lhs_l1_pack_new = transform.structured.bufferize_to_allocation %fused_lhs_l1_pack3 - {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op - - // Promote the RHS to local memory. - %rhs_l1_pack_buffer, %rhs_l1_pack_new = transform.structured.bufferize_to_allocation %fused_rhs_l1_pack3 - {memory_space = 2, bufferize_destination_only, memcpy_op = "linalg.copy", emit_dealloc} : !transform.any_op - - // Run canonicalization - %func3 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func3 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func3 : !transform.any_op - - // Hoist static alloc out of the loops - %func8 = transform.structured.match ops{["func.func"]} in %arg1 - : (!transform.any_op) -> !transform.any_op - transform.air.hoist_static_alloc %func8 : (!transform.any_op) -> () - - // Peel the for loop - %for_op = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.for"> - - // Find the producer operation (fill), and tile using for_all, as the prologue. - %fill_op = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %prologue_tiled_fill, %prologue_forall = - transform.structured.tile_using_forall %fill_op tile_sizes [1, 1] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - - // Find the consumer operation (unpack), and tile using for_all, as the epilogue. - %unpack_ops = transform.structured.match ops{["linalg.unpack"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %l1_to_l2_unpack, %l2_to_l3_unpack = transform.split_handle %unpack_ops : (!transform.any_op<"linalg.unpack">) -> (!transform.any_op<"linalg.unpack">, !transform.any_op<"linalg.unpack">) - %epilogue_tiled_unpack, %epilogue_forall = - transform.structured.tile_using_forall %l1_to_l2_unpack tile_sizes [1, 1] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - - // Run canonicalization - %func5 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func5 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func5 : !transform.any_op - - // Bufferize - %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_bufferized = transform.bufferization.one_shot_bufferize %func_op : (!transform.any_op) -> !transform.any_op - - // Run canonicalization to remove redundant memcpy (with linalg.generic form) ops created, which can be deleted by canonicalizer. We have to run it again because the memrefs are unified in CSE pass, so we can truely remove redundant memcpy. - %func6 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func6 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func6 : !transform.any_op - transform.apply_patterns to %func6 { - transform.apply_patterns.canonicalization - } : !transform.any_op - %func_op_updated = transform.air.remove_uninitialized_copy %func6 : (!transform.any_op) -> !transform.any_op - %func_op_updated_1 = transform.air.eliminate_cascade_memcpy %func_op_updated : (!transform.any_op) -> !transform.any_op - - // Tile linalg.generics for vectorization - %linalg_generics = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %inner_most_generics, %vec_loops:6 = - transform.structured.tile_using_for %linalg_generics tile_sizes [1, 1, 1, 1, 1, 1, 0, 0, 0] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - - // Tile linalg.fills for vectorized write - %linalg_fills = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %inner_most_fills, %vec_fill_loops:4 = - transform.structured.tile_using_for %linalg_fills tile_sizes [1, 1, 1, 1] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - transform.yield - } +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.consumed}) { + + transform.apply_registered_pass "air-matmul-codegen" with options = { + "launch-tile" = [256, 256], + "l2-pack-sizes" = [64, 64, 64], + "l2-lhs-outer-perm" = [0, 1], "l2-lhs-inner-perm" = [0, 1], + "l2-rhs-outer-perm" = [1, 0], "l2-rhs-inner-perm" = [1, 0], + "l2-acc-outer-perm" = [0, 1], "l2-acc-inner-perm" = [0, 1], + "bufferize-output-l2" = true, + "l1-pack-sizes" = [0, 0, 0, 8, 8, 8], + "l1-lhs-outer-perm" = [0, 1, 3, 2], + "l1-rhs-outer-perm" = [0, 1, 3, 2], "l1-rhs-inner-perm" = [1, 0], + "l1-acc-outer-perm" = [0, 1, 3, 2], + "outer-k-tile-factor" = 1, "outer-k-iter-index" = 2, + "core-tile" = [1, 1, 0, 0, 0, 0, 0, 0, 0], + "inner-k-tile-factor" = 8, "inner-k-iter-index" = 5, + "prologue-tile" = [1, 1], "epilogue-tile" = [1, 1], + "hoist-static-alloc-first" = true, + "one-shot-bufferize" = true, + "post-bufferize-cleanup-first" = true, + "matmul-vec-tile" = [1, 1, 1, 1, 1, 1, 0, 0, 0], + "matmul-unroll-vec-tile" = [0, 0, 0, 0, 0, 0, 0, 0, 0], + "matmul-unroll-factor" = 1, + "fill-vec-tile" = [1, 1, 1, 1] + } to %arg1 : (!transform.any_op) -> !transform.any_op + + transform.yield + } } diff --git a/test/xrt/39_triton_matmul_ver3_vectorized/run.py b/test/xrt/39_triton_matmul_ver3_vectorized/run.py index 5099f2452..7d7c65394 100644 --- a/test/xrt/39_triton_matmul_ver3_vectorized/run.py +++ b/test/xrt/39_triton_matmul_ver3_vectorized/run.py @@ -25,7 +25,13 @@ type=str, dest="transform_script", default="transform.mlir", - help="Transform script path", + help="Transform script path (legacy path).", +) +parser.add_argument( + "--use-cpp-pipeline", + action="store_true", + help="Replace the legacy transform script with the C++ matmul codegen " + "orchestrator (air-matmul-codegen). Targets aie2 / NPU1 (mmul=4x4x8).", ) args = parser.parse_args() @@ -84,11 +90,55 @@ pm = air.passmanager.PassManager.parse(pipeline) pm.run(air_module.operation) - # Load the MLIR transform IR from an external file - with open(args.transform_script, "r") as f: - transform_ir_string = f.read() - transform_ir = Module.parse(transform_ir_string) - run_transform(transform_ir, air_module) + if args.use_cpp_pipeline: + # Single-pack-level NPU1 (aie2) flow via the C++ orchestrator. + # mmul=[4,4,8]. Per-launch matmul is 256x256x512; orchestrator's + # launch-tile=64,64 creates an outer scf.forall (4x4 herd) wrapping + # an inner 64x64 matmul. No L3->L2 copy tiling, no fuse-truncf + # (output is f32). No prologue/epilogue tiling (test 39's transform + # script doesn't separate them). + cpp_pipeline = ( + "builtin.module(" + "air-matmul-codegen{" + # Phase A: launch-tile = 64x64 (the only parallel tile in this + # flow). Becomes the outer scf.forall, mapped to a 4x4 herd. + "launch-tile=64,64 " + # Phase C: bufferize fill output to L2. + "bufferize-output-l2=true " + # Phase B: single-pack [4, 4, 8] (aie2 mmul). + "l2-pack-sizes=4,4,8 " + "l2-lhs-outer-perm=1,0 " + "l2-rhs-outer-perm=1,0 l2-rhs-inner-perm=1,0 " + "l2-acc-outer-perm=1,0 " + # Phase E: K-tile factor=4 (matches transform's tile_using_for " + # [0, 0, 4]). + "outer-k-tile-factor=4 outer-k-iter-index=2 " + # No core-tile (the launch-tile is the only parallel tile). + # No inner K-tile, no prologue/epilogue. + # Phase L: upstream one-shot-bufferize. + "one-shot-bufferize=true " + # Phase M: tile-for-vectorize at [1, 1, 1, 0, 0, 0]; no second- + # level unroll. + "matmul-vec-tile=1,1,1,0,0,0 " + "matmul-unroll-factor=1 fill-vec-tile=1,1 " + # Phase N: no vec-prep (test 39 doesn't run any vec-prep steps). + "}, " + "func.func(scf-forall-to-parallel), " + "air-par-to-herd, " + "func.func(air-herd-vectorize), " + "func.func(canonicalize,cse,fold-memref-alias-ops), " + # Cleanup orchestrator pass after vectorization. + "air-matmul-codegen{}" + ")" + ) + pm = air.passmanager.PassManager.parse(cpp_pipeline) + pm.run(air_module.operation) + else: + # Load the MLIR transform IR from an external file + with open(args.transform_script, "r") as f: + transform_ir_string = f.read() + transform_ir = Module.parse(transform_ir_string) + run_transform(transform_ir, air_module) ################################################ ## Binding scf.paralell to air hierarchies diff --git a/test/xrt/44_triton_matmul_ver4_vector_ptr_opt/run.py b/test/xrt/44_triton_matmul_ver4_vector_ptr_opt/run.py index f09fa59b7..9845c96b3 100644 --- a/test/xrt/44_triton_matmul_ver4_vector_ptr_opt/run.py +++ b/test/xrt/44_triton_matmul_ver4_vector_ptr_opt/run.py @@ -25,7 +25,7 @@ type=str, dest="transform_script", default="transform.mlir", - help="Transform script path", + help="Transform script path (legacy path).", ) parser.add_argument( "--output-format", @@ -92,7 +92,8 @@ pm = air.passmanager.PassManager.parse(pipeline) pm.run(air_module.operation) - # Load the MLIR transform IR from an external file + # Drive matmul codegen via the transform script (which delegates to the + # C++ air-matmul-codegen orchestrator via transform.apply_registered_pass). with open(args.transform_script, "r") as f: transform_ir_string = f.read() transform_ir = Module.parse(transform_ir_string) diff --git a/test/xrt/44_triton_matmul_ver4_vector_ptr_opt/transform_aie2.mlir b/test/xrt/44_triton_matmul_ver4_vector_ptr_opt/transform_aie2.mlir index 7137bb885..0442ad39e 100644 --- a/test/xrt/44_triton_matmul_ver4_vector_ptr_opt/transform_aie2.mlir +++ b/test/xrt/44_triton_matmul_ver4_vector_ptr_opt/transform_aie2.mlir @@ -1,354 +1,73 @@ -// Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. // SPDX-License-Identifier: MIT - -//////////////////////////////////////////////////////////////////////////////// -// Transform Script for Matmul (Triton Ver4, Vectorized): Step-by-Step Annotated -// This script transforms a matmul IR into a tiled, packed, bufferized, and -// hardware-friendly form suitable for AIE execution. Each step is annotated -// with its purpose, assumptions, and relation to the IR. // -// Target configuration: 8x4 AIE core array (Phoenix) -// Data types: BF16 inputs, F32 accumulation -//////////////////////////////////////////////////////////////////////////////// +// AIE2 (Phoenix) single-pack-level f32-out matmul codegen via the C++ +// air-matmul-codegen orchestrator. mmul=4x4x8, core-tile=16x16. module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - - //========================================================================== - // PHASE 1: TILE L3->L2 MEMORY COPIES - // Purpose: Tile the memref copy ops that move data from L3 (DDR) to L2 (shared memory). - //========================================================================== - - // Step 1: Convert memref.copy to linalg.copy and tile for L3->L2 data movement. - // Purpose: Transforms memref copies into tileable linalg operations for streaming data. - // Assumption: The IR contains memref.copy ops for A and B matrices. - %func10 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func10_updated = transform.air.convert_memref_copy_to_linalg_copy %func10 : (!transform.any_op) -> !transform.any_op - %copies = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %copy1, %copy2 = transform.split_handle %copies : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %tiled_copy1, %tile_copy_loop1 = - transform.structured.tile_using_for %copy1 tile_sizes [0, 64] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %tile_copy_loop1 "copy_a_loop" : !transform.any_op - %tiled_copy2, %tile_copy_loop2 = - transform.structured.tile_using_for %copy2 tile_sizes [64] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %tile_copy_loop2 "copy_b_loop" : !transform.any_op - - //========================================================================== - // PHASE 2: MATCH AND PREPARE CORE OPERATIONS - // Purpose: Identify fill and matmul operations, promote output to L2. - //========================================================================== - - // Step 2: Match the fill and matmul ops. - // Assumption: The IR contains linalg.fill and linalg.matmul ops representing - // initialization and main computation. - %fill = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - - // Step 3: Promote the result buffer (C matrix) to L2 shared memory. - // Purpose: Allocate output buffer in L2 for accumulation before writing back to L3. - // memory_space = 1 corresponds to L2 (shared memory). - %result_l2 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %result_l2_buffer, %result_t2_new = transform.structured.bufferize_to_allocation %result_l2 - {memory_space = 1, bufferize_destination_only, memcpy = "linalg.copy", emit_dealloc} : !transform.any_op - - //========================================================================== - // PHASE 3: PACK MATMUL FOR VECTORIZED COMPUTATION - // Purpose: Apply data tiling (packing) to enable efficient vectorized computation. - //========================================================================== - - // Step 4: Pack matmul with tile sizes [4, 4, 8]. - // Purpose: Transforms linalg.matmul into linalg.generic with packed layout. - // Assumption: Pack sizes [4, 4, 8] correspond to M, N, K tile dimensions for - // efficient AIE vector unit utilization. - %packed = transform.structured.pack %matmul packed_sizes = [4, 4, 8] - : (!transform.any_op) -> (!transform.any_op) - - // Step 5: Transpose A matrix for packed layout. - // Purpose: Ensures A operand has correct memory layout for vectorized access. - // Outer permutation [1, 0] swaps the outer tile dimensions. - %pack_producer_a = transform.get_producer_of_operand %packed[0] - : (!transform.any_op) -> (!transform.any_op) - %packed_a, %pack_a, %empty_unpack_a = - transform.structured.pack_transpose %pack_producer_a with_compute_op(%packed) - outer_perm = [1, 0] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - // Step 6: Transpose B matrix for packed layout. - // Purpose: Ensures B operand has correct memory layout for vectorized access. - // Both outer_perm and inner_perm [1, 0] transpose outer and inner tile dimensions. - %pack_producer_b = transform.get_producer_of_operand %packed_a[1] - : (!transform.any_op) -> (!transform.any_op) - %packed_b, %pack_b, %empty_unpack_b = - transform.structured.pack_transpose %pack_producer_b with_compute_op(%packed_a) - outer_perm = [1, 0] inner_perm = [1, 0] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - // Step 7: Transpose C matrix for packed layout. - // Purpose: Ensures C operand has correct memory layout matching A and B. - // Outer permutation [1, 0] aligns output tile dimensions. - %unpack = transform.get_consumers_of_result %packed_b[0] - : (!transform.any_op) -> (!transform.any_op) - %packed_c, %pack_c, %unpack_c = - transform.structured.pack_transpose %unpack with_compute_op(%packed_b) - outer_perm = [1, 0] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - // Step 8: Promote the output pack operation to L1 local memory. - // Purpose: Allocate L1 buffer for C matrix tiles during computation. - // memory_space = 2 corresponds to L1 (AIE local memory). - %output_l1_pack_op_source_buffer, %output_l1_pack_op_new = transform.structured.bufferize_to_allocation %pack_c - {memory_space = 2, bufferize_destination_only, memcpy_op = "linalg.copy", emit_dealloc} : !transform.any_op - - //========================================================================== - // PHASE 4: TILE REDUCTION AND FUSE PACK OPERATIONS - // Purpose: Tile the K dimension and fuse data movement into compute loops. - //========================================================================== - - // Step 9: Tile the reduction (K) dimension. - // Purpose: Enables streaming of A and B tiles along K dimension. - // Tile size [0, 0, 8] tiles only the K dimension with factor 8. - %tiled_reduction, %outer_for_loop = - transform.structured.tile_using_for %packed_c tile_sizes [0, 0, 8] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %outer_for_loop "k_reduction_loop" : !transform.any_op - - // Step 10: Fuse pack operations for A and B into the outer K-loop. - // Purpose: Moves data packing inside the loop for better locality and pipelining. - %fused_lhs_l1_pack, %2 = transform.structured.fuse_into_containing_op %pack_a into %outer_for_loop : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - %fused_rhs_l1_pack, %3 = transform.structured.fuse_into_containing_op %pack_b into %outer_for_loop : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - - //========================================================================== - // PHASE 5: TILE FOR MULTI-CORE PARALLELISM - // Purpose: Create parallel loops for mapping to 8x4 AIE core array. - //========================================================================== - - // Step 11: Tile matmul using scf.forall with tile size [16, 16, 0]. - // Purpose: Introduces parallelism across M and N dimensions for multi-core execution. - // Tile sizes [16, 16, 0] create 4x4 tiles for each AIE core to process. - %matmul_1 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %tiled_matmul_1, %inner_forall = - transform.structured.tile_using_forall %matmul_1 tile_sizes [16, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %inner_forall "compute_forall" : !transform.any_op - transform.annotate %tiled_matmul_1 "matmul_compute" : !transform.any_op - - // Step 12: Fuse pack operations into the inner parallel loop. - // Purpose: Ensures each core has its own data packing for independent execution. - %fused_lhs_l1_pack2, %6 = transform.structured.fuse_into_containing_op %fused_lhs_l1_pack into %inner_forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - %fused_rhs_l1_pack2, %7 = transform.structured.fuse_into_containing_op %fused_rhs_l1_pack into %inner_forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - - // Step 13: Canonicalization and CSE after tiling. - // Purpose: Cleans up IR, merges redundant ops, and prepares for further transforms. - %func_2 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func_2 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func_2 : !transform.any_op - - //========================================================================== - // PHASE 6: PROMOTE INPUTS TO L1 AND TILE PROLOGUE/EPILOGUE - // Purpose: Move input data to L1, create tiled fill (prologue) and unpack (epilogue). - //========================================================================== - - // Step 14: Promote input operands (A and B tiles) to L1 local memory. - // Purpose: Allocates L1 buffers for fast access during computation. - // memory_space = 2 corresponds to L1 (AIE local memory). - %buffer_a, %new_a = transform.structured.bufferize_to_allocation %fused_lhs_l1_pack2 - {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op - %buffer_b, %new_b = transform.structured.bufferize_to_allocation %fused_rhs_l1_pack2 - {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op - - // Step 15: Create tiled prologue (fill operation). - // Purpose: Initializes output buffers in parallel across cores. - // Generalize fill to generic, interchange dimensions, then tile with forall. - %fill_op = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %generic_fill_op = transform.structured.generalize %fill_op - : (!transform.any_op) -> !transform.any_op - transform.annotate %generic_fill_op "init_fill" : !transform.any_op - %interchanged_fill_op = transform.structured.interchange %generic_fill_op - iterator_interchange = [1, 0, 2, 3] - : (!transform.any_op) -> !transform.any_op - %prologue_tiled_fill, %prologue_forall = - transform.structured.tile_using_forall %interchanged_fill_op tile_sizes [16, 16] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %prologue_forall "prologue_forall" : !transform.any_op - - // Step 16: Create tiled epilogue (unpack operation). - // Purpose: Unpacks and writes results back to L2 in parallel across cores. - // Tile sizes [64, 64] match the L2 tile dimensions. - %unpack_op = transform.structured.match ops{["linalg.unpack"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %epilogue_tiled_unpack, %epilogue_forall = - transform.structured.tile_using_forall %unpack_op tile_sizes [64, 64] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %epilogue_forall "epilogue_forall" : !transform.any_op - - // Step 17: Canonicalization and CSE after buffer promotion. - // Purpose: Merges redundant allocs/copies and simplifies the IR. - %func_3 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func_3 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func_3 : !transform.any_op - - //========================================================================== - // PHASE 7: BUFFERIZATION AND AIR CLEANUP - // Purpose: Convert tensors to memrefs and optimize memory operations. - //========================================================================== - - // Step 18: One-shot bufferization of the function. - // Purpose: Converts all remaining tensors to memrefs for hardware execution. - %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_bufferized = transform.bufferization.one_shot_bufferize %func_op : (!transform.any_op) -> !transform.any_op - - // Step 19: AIR-specific cleanup and memory optimization. - // Purpose: Removes uninitialized copies and eliminates redundant cascade memcpy patterns. - %func6 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func6 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func6 : !transform.any_op - transform.apply_patterns to %func6 { - transform.apply_patterns.canonicalization - } : !transform.any_op - %func_op_updated = transform.air.remove_uninitialized_copy %func6 : (!transform.any_op) -> !transform.any_op - %func_op_updated_1 = transform.air.eliminate_cascade_memcpy %func_op_updated : (!transform.any_op) -> !transform.any_op - - //========================================================================== - // PHASE 8: FUSE LOOPS FOR L2 PINGPONG BUFFERING - // Purpose: Fuse L3->L2 copy loops with main compute loop for double buffering. - //========================================================================== - - // Step 20: Fuse L3->L2 copy loops with the main K-reduction loop. - // Purpose: Expose L2 pingpong buffering opportunity by interleaving L3->L2 data transfer with L2->L1. - // Use annotation-based matching instead of fragile split_handle. - %for_loop_copy_1 = transform.structured.match ops{["scf.for"]} attributes{copy_a_loop} in %arg1 : (!transform.any_op) -> !transform.any_op - %for_loop_copy_2 = transform.structured.match ops{["scf.for"]} attributes{copy_b_loop} in %arg1 : (!transform.any_op) -> !transform.any_op - %main_for_loop = transform.structured.match ops{["scf.for"]} attributes{k_reduction_loop} in %arg1 : (!transform.any_op) -> !transform.any_op - %main_for_loop_norm = transform.air.normalize_for_bounds %main_for_loop : (!transform.any_op) -> !transform.any_op // Fold affine apply into for loop bound - transform.apply_cse to %func_op_updated_1 : !transform.any_op // Ensure loop bounds use shared cst ssa values - %fused_for_loop_2 = transform.loop.fuse_sibling %for_loop_copy_2 into %main_for_loop_norm - : (!transform.any_op, !transform.any_op) -> !transform.any_op - %fused_for_loop_1 = transform.loop.fuse_sibling %for_loop_copy_1 into %fused_for_loop_2 - : (!transform.any_op, !transform.any_op) -> !transform.any_op - - //========================================================================== - // PHASE 9: TILE FOR VECTORIZATION - // Purpose: Final tiling to enable efficient vectorized execution on AIE vector units. - //========================================================================== - - // Step 21: Tile linalg.generic (matmul) for vectorization. - // Purpose: Creates inner loops with sizes suitable for vector register usage. - // Tile sizes [2, 2, 1, 0, 0, 0] unroll M and N by 2 for register blocking. - // Use annotation-based matching instead of fragile split_handle. - %generic1 = transform.structured.match ops{["linalg.generic"]} attributes{init_fill} in %arg1 : (!transform.any_op) -> !transform.any_op - %generic2 = transform.structured.match ops{["linalg.generic"]} attributes{matmul_compute} in %arg1 : (!transform.any_op) -> !transform.any_op - %inner_most_generics, %vec_loops:3 = - transform.structured.tile_using_for %generic2 tile_sizes [2, 2, 1, 0, 0, 0] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - - // Step 22: Further tile and unroll innermost loops for full vectorization. - // Purpose: Completely unrolls the innermost M and N loops for register allocation. - %inner_most_matmul_to_unroll, %vec_loops_to_unroll:2 = - transform.structured.tile_using_for %inner_most_generics tile_sizes [1, 1, 0, 0, 0, 0] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - transform.loop.unroll %vec_loops_to_unroll#1 {factor = 2} : !transform.any_op - transform.loop.unroll %vec_loops_to_unroll#0 {factor = 2} : !transform.any_op - - // Step 23: Tile linalg.generic (fill) for vectorized initialization. - // Purpose: Creates vector-sized tiles for efficient zero-initialization. - %inner_most_fills, %vec_fill_loops:2 = - transform.structured.tile_using_for %generic1 tile_sizes [1, 1] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - - //========================================================================== - // PHASE 10: CONVERT TO AIE HERDS AND VECTORIZE - // Purpose: Map parallel loops to AIE cores (herds) and apply vectorization. - //========================================================================== - - // Step 24: Convert scf.forall loops to AIE herd operations. - // Purpose: Maps parallel work to the 8x4 AIE core array. - // Each forall becomes an air.herd representing multi-core execution. - // Use annotation-based matching instead of fragile split_handle. - %forall1 = transform.structured.match ops{["scf.forall"]} attributes{prologue_forall} in %arg1 : (!transform.any_op) -> !transform.any_op - %forall2 = transform.structured.match ops{["scf.forall"]} attributes{compute_forall} in %arg1 : (!transform.any_op) -> !transform.any_op - %forall3 = transform.structured.match ops{["scf.forall"]} attributes{epilogue_forall} in %arg1 : (!transform.any_op) -> !transform.any_op - %parallel1 = transform.loop.forall_to_parallel %forall1 : (!transform.any_op) -> !transform.any_op - %herd1 = transform.air.par_to_herd %parallel1 : (!transform.any_op) -> !transform.any_op - transform.annotate %herd1 "prologue_herd" : !transform.any_op - %parallel2 = transform.loop.forall_to_parallel %forall2 : (!transform.any_op) -> !transform.any_op - %herd2 = transform.air.par_to_herd %parallel2 : (!transform.any_op) -> !transform.any_op - transform.annotate %herd2 "compute_herd" : !transform.any_op - %parallel3 = transform.loop.forall_to_parallel %forall3 : (!transform.any_op) -> !transform.any_op - %herd3 = transform.air.par_to_herd %parallel3 : (!transform.any_op) -> !transform.any_op - transform.annotate %herd3 "epilogue_herd" : !transform.any_op - - // Step 25: Apply vectorization to AIE herds. - // Purpose: Converts scalar operations to vector operations for AIE vector units. - %herds = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %vectorized_herds = transform.air.herd_vectorize %herds : (!transform.any_op) -> !transform.any_op - - // Step 26: Canonicalization after vectorization. - // Purpose: Simplifies vector operations and folds unit extent dimensions. - %func7 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func7 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - transform.apply_patterns.memref.fold_memref_alias_ops - } : !transform.any_op - %func_fold_1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_folded_1 = transform.air.fold_unit_extent_dims %func_fold_1 : (!transform.any_op) -> !transform.any_op - - // Step 27: Eliminate redundant vector.transfer_read operations. - // Purpose: Removes duplicate memory reads for better performance. - %func7_rematch = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func1_optimized = transform.air.eliminate_redundant_vector_transfers %func7_rematch : (!transform.any_op) -> !transform.any_op - - //========================================================================== - // PHASE 11: HOIST LOOP-INVARIANT VECTOR TRANSFERS - // Purpose: Move vector reads/writes out of innermost loops for register reuse. - //========================================================================== - - // Step 28: Match the compute herd and prepare for hoisting optimization. - // Purpose: Identifies the compute herd and its vector operations for register optimization. - // Use annotation-based matching instead of fragile split_handle. - %herd2_1 = transform.structured.match ops{["air.herd"]} attributes{compute_herd} in %arg1 : (!transform.any_op) -> !transform.any_op - - // Step 29: Identify the innermost loop for hoisting. - %scf_fors_1 = transform.structured.match ops{["scf.for"]} in %herd2_1 : (!transform.any_op) -> !transform.any_op - %innermost_for, %outer_fors = transform.split_handle %scf_fors_1 {overflow_result = 1} : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - - // Step 31: Cast vector types for correct accumulation precision. - // Purpose: Ensures vector.contract uses F32 for accumulation (BF16 inputs -> F32 output). - %vector_contracts = transform.structured.match ops{["vector.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %result11 = transform.air.vector_type_cast %vector_contracts {target_element_type = f32, input_indices = [2], output_indices = [0]} : (!transform.any_op) -> !transform.any_op - - // Step 32: Hoist all accumulator transfer pairs from innermost loop. - %innermost_for_updated_3 = transform.air.hoist_loop_invariant_transfers %herd2_1, %innermost_for : (!transform.any_op, !transform.any_op) -> !transform.any_op - - // Step 33: Flatten loop iteration arguments and hoist vector transfer pointers. - // Purpose: Simplifies loop structure and moves pointer computations out of loops. - %innermost_for_updated_4 = transform.air.flatten_for_iter_args %innermost_for_updated_3 : (!transform.any_op) -> !transform.any_op - %innermost_for_updated_5 = transform.air.hoist_vector_transfer_pointers %innermost_for_updated_4 : (!transform.any_op) -> !transform.any_op - - // Step 34: Final canonicalization pass. - // Purpose: Cleans up the final IR for AIR/AIE lowering. - %func9 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func9 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - transform.apply_patterns.memref.fold_memref_alias_ops - } : !transform.any_op - %func_fold_2 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_folded_2 = transform.air.fold_unit_extent_dims %func_fold_2 : (!transform.any_op) -> !transform.any_op + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.consumed}) { + + %m1 = transform.apply_registered_pass "air-matmul-codegen" with options = { + "bufferize-output-l2" = true, + "tile-l3-to-l2-copies" = true, "k-l2-tile" = 64, + "l2-pack-sizes" = [4, 4, 8], + "l2-lhs-outer-perm" = [1, 0], "l2-lhs-inner-perm" = [0, 1], + "l2-rhs-outer-perm" = [1, 0], "l2-rhs-inner-perm" = [1, 0], + "l2-acc-outer-perm" = [1, 0], "l2-acc-inner-perm" = [0, 1], + "outer-k-tile-factor" = 8, "outer-k-iter-index" = 2, + "core-tile" = [16, 16, 0], + "prologue-tile" = [16, 16], "epilogue-tile" = [64, 64], + "fill-iter-perm" = [1, 0, 2, 3], + "one-shot-bufferize" = true, + "post-bufferize-cleanup-first" = true, + "matmul-vec-tile" = [2, 2, 1, 0, 0, 0], + "matmul-unroll-vec-tile" = [1, 1, 0, 0, 0, 0], + "matmul-unroll-factor" = 2, + "fill-vec-tile" = [1, 1, 0, 0] + } to %arg1 : (!transform.any_op) -> !transform.any_op + + %func1 = transform.structured.match ops{["func.func"]} in %m1 + : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "scf-forall-to-parallel" to %func1 + : (!transform.any_op) -> !transform.any_op + %m2 = transform.apply_registered_pass "air-par-to-herd" to %m1 + : (!transform.any_op) -> !transform.any_op + %func2 = transform.structured.match ops{["func.func"]} in %m2 + : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "air-herd-vectorize" to %func2 + : (!transform.any_op) -> !transform.any_op + + %func3a = transform.structured.match ops{["func.func"]} in %m2 + : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "canonicalize" to %func3a + : (!transform.any_op) -> !transform.any_op + %func3b = transform.structured.match ops{["func.func"]} in %m2 + : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "cse" to %func3b + : (!transform.any_op) -> !transform.any_op + %func3c = transform.structured.match ops{["func.func"]} in %m2 + : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "fold-memref-alias-ops" to %func3c + : (!transform.any_op) -> !transform.any_op + + %m3 = transform.apply_registered_pass "air-matmul-codegen" with options = { + "vec-prep-cast1-target-element-type" = "f32", + "vec-prep-cast1-input-indices" = [2], + "vec-prep-cast1-output-indices" = [0] + } to %m2 : (!transform.any_op) -> !transform.any_op + + %func4a = transform.structured.match ops{["func.func"]} in %m3 + : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "canonicalize" to %func4a + : (!transform.any_op) -> !transform.any_op + %func4b = transform.structured.match ops{["func.func"]} in %m3 + : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "cse" to %func4b + : (!transform.any_op) -> !transform.any_op + %func4c = transform.structured.match ops{["func.func"]} in %m3 + : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "fold-memref-alias-ops" to %func4c + : (!transform.any_op) -> !transform.any_op transform.yield } diff --git a/test/xrt/44_triton_matmul_ver4_vector_ptr_opt/transform_aie2p.mlir b/test/xrt/44_triton_matmul_ver4_vector_ptr_opt/transform_aie2p.mlir index 1551daad3..b3190de6f 100644 --- a/test/xrt/44_triton_matmul_ver4_vector_ptr_opt/transform_aie2p.mlir +++ b/test/xrt/44_triton_matmul_ver4_vector_ptr_opt/transform_aie2p.mlir @@ -1,354 +1,95 @@ -// Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. // SPDX-License-Identifier: MIT - -//////////////////////////////////////////////////////////////////////////////// -// Transform Script for Matmul (Triton Ver4, Vectorized): Step-by-Step Annotated -// This script transforms a matmul IR into a tiled, packed, bufferized, and -// hardware-friendly form suitable for AIE execution. Each step is annotated -// with its purpose, assumptions, and relation to the IR. // -// Target configuration: 8x4 AIE core array (Strix) -// Data types: BF16 inputs, F32 accumulation -//////////////////////////////////////////////////////////////////////////////// +// AIE2P (Strix) single-pack-level f32-out matmul codegen via the C++ +// air-matmul-codegen orchestrator. mmul=8x8x8, core-tile=8x8. module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - - //========================================================================== - // PHASE 1: TILE L3->L2 MEMORY COPIES - // Purpose: Tile the memref copy ops that move data from L3 (DDR) to L2 (shared memory). - //========================================================================== - - // Step 1: Convert memref.copy to linalg.copy and tile for L3->L2 data movement. - // Purpose: Transforms memref copies into tileable linalg operations for streaming data. - // Assumption: The IR contains memref.copy ops for A and B matrices. - %func10 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func10_updated = transform.air.convert_memref_copy_to_linalg_copy %func10 : (!transform.any_op) -> !transform.any_op - %copies = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %copy1, %copy2 = transform.split_handle %copies : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %tiled_copy1, %tile_copy_loop1 = - transform.structured.tile_using_for %copy1 tile_sizes [0, 64] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %tile_copy_loop1 "copy_a_loop" : !transform.any_op - %tiled_copy2, %tile_copy_loop2 = - transform.structured.tile_using_for %copy2 tile_sizes [64] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %tile_copy_loop2 "copy_b_loop" : !transform.any_op - - //========================================================================== - // PHASE 2: MATCH AND PREPARE CORE OPERATIONS - // Purpose: Identify fill and matmul operations, promote output to L2. - //========================================================================== - - // Step 2: Match the fill and matmul ops. - // Assumption: The IR contains linalg.fill and linalg.matmul ops representing - // initialization and main computation. - %fill = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - - // Step 3: Promote the result buffer (C matrix) to L2 shared memory. - // Purpose: Allocate output buffer in L2 for accumulation before writing back to L3. - // memory_space = 1 corresponds to L2 (shared memory). - %result_l2 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %result_l2_buffer, %result_t2_new = transform.structured.bufferize_to_allocation %result_l2 - {memory_space = 1, bufferize_destination_only, mempcy = "linalg.copy", emit_dealloc} : !transform.any_op - - //========================================================================== - // PHASE 3: PACK MATMUL FOR VECTORIZED COMPUTATION - // Purpose: Apply data tiling (packing) to enable efficient vectorized computation. - //========================================================================== - - // Step 4: Pack matmul with tile sizes [8, 8, 8]. - // Purpose: Transforms linalg.matmul into linalg.generic with packed layout. - // Assumption: Pack sizes [8, 8, 8] correspond to M, N, K tile dimensions for - // efficient AIE vector unit utilization. - %packed = transform.structured.pack %matmul packed_sizes = [8, 8, 8] - : (!transform.any_op) -> (!transform.any_op) - - // Step 5: Transpose A matrix for packed layout. - // Purpose: Ensures A operand has correct memory layout for vectorized access. - // Outer permutation [1, 0] swaps the outer tile dimensions. - %pack_producer_a = transform.get_producer_of_operand %packed[0] - : (!transform.any_op) -> (!transform.any_op) - %packed_a, %pack_a, %empty_unpack_a = - transform.structured.pack_transpose %pack_producer_a with_compute_op(%packed) - outer_perm = [1, 0] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - // Step 6: Transpose B matrix for packed layout. - // Purpose: Ensures B operand has correct memory layout for vectorized access. - // Both outer_perm and inner_perm [1, 0] transpose outer and inner tile dimensions. - %pack_producer_b = transform.get_producer_of_operand %packed_a[1] - : (!transform.any_op) -> (!transform.any_op) - %packed_b, %pack_b, %empty_unpack_b = - transform.structured.pack_transpose %pack_producer_b with_compute_op(%packed_a) - outer_perm = [1, 0] inner_perm = [1, 0] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - // Step 7: Transpose C matrix for packed layout. - // Purpose: Ensures C operand has correct memory layout matching A and B. - // Outer permutation [1, 0] aligns output tile dimensions. - %unpack = transform.get_consumers_of_result %packed_b[0] - : (!transform.any_op) -> (!transform.any_op) - %packed_c, %pack_c, %unpack_c = - transform.structured.pack_transpose %unpack with_compute_op(%packed_b) - outer_perm = [1, 0] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - // Step 8: Promote the output pack operation to L1 local memory. - // Purpose: Allocate L1 buffer for C matrix tiles during computation. - // memory_space = 2 corresponds to L1 (AIE local memory). - %output_l1_pack_op_source_buffer, %output_l1_pack_op_new = transform.structured.bufferize_to_allocation %pack_c - {memory_space = 2, bufferize_destination_only, memcpy_op = "linalg.copy", emit_dealloc} : !transform.any_op - - //========================================================================== - // PHASE 4: TILE REDUCTION AND FUSE PACK OPERATIONS - // Purpose: Tile the K dimension and fuse data movement into compute loops. - //========================================================================== - - // Step 9: Tile the reduction (K) dimension. - // Purpose: Enables streaming of A and B tiles along K dimension. - // Tile size [0, 0, 8] tiles only the K dimension with factor 8. - %tiled_reduction, %outer_for_loop = - transform.structured.tile_using_for %packed_c tile_sizes [0, 0, 8] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %outer_for_loop "k_reduction_loop" : !transform.any_op - - // Step 10: Fuse pack operations for A and B into the outer K-loop. - // Purpose: Moves data packing inside the loop for better locality and pipelining. - %fused_lhs_l1_pack, %2 = transform.structured.fuse_into_containing_op %pack_a into %outer_for_loop : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - %fused_rhs_l1_pack, %3 = transform.structured.fuse_into_containing_op %pack_b into %outer_for_loop : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - - //========================================================================== - // PHASE 5: TILE FOR MULTI-CORE PARALLELISM - // Purpose: Create parallel loops for mapping to 8x4 AIE core array. - //========================================================================== + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.consumed}) { - // Step 11: Tile matmul using scf.forall with tile size [8, 8, 0]. - // Purpose: Introduces parallelism across M and N dimensions for multi-core execution. - // Tile sizes [8, 8, 0] create 8x8 tiles for each AIE core to process. - %matmul_1 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %tiled_matmul_1, %inner_forall = - transform.structured.tile_using_forall %matmul_1 tile_sizes [8, 8, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %inner_forall "compute_forall" : !transform.any_op - transform.annotate %tiled_matmul_1 "matmul_compute" : !transform.any_op + %m1 = transform.apply_registered_pass "air-matmul-codegen" with options = { + "bufferize-output-l2" = true, + "tile-l3-to-l2-copies" = true, "k-l2-tile" = 64, + "l2-pack-sizes" = [8, 8, 8], + "l2-lhs-outer-perm" = [1, 0], "l2-lhs-inner-perm" = [0, 1], + "l2-rhs-outer-perm" = [1, 0], "l2-rhs-inner-perm" = [1, 0], + "l2-acc-outer-perm" = [1, 0], "l2-acc-inner-perm" = [0, 1], + "outer-k-tile-factor" = 8, "outer-k-iter-index" = 2, + "core-tile" = [8, 8, 0], + "prologue-tile" = [8, 8], "epilogue-tile" = [64, 64], + "fill-iter-perm" = [1, 0, 2, 3], + "one-shot-bufferize" = true, + "post-bufferize-cleanup-first" = true, + "matmul-vec-tile" = [2, 2, 1, 0, 0, 0], + "matmul-unroll-vec-tile" = [1, 1, 0, 0, 0, 0], + "matmul-unroll-factor" = 2, + "fill-vec-tile" = [1, 1, 0, 0] + } to %arg1 : (!transform.any_op) -> !transform.any_op - // Step 12: Fuse pack operations into the inner parallel loop. - // Purpose: Ensures each core has its own data packing for independent execution. - %fused_lhs_l1_pack2, %6 = transform.structured.fuse_into_containing_op %fused_lhs_l1_pack into %inner_forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - %fused_rhs_l1_pack2, %7 = transform.structured.fuse_into_containing_op %fused_rhs_l1_pack into %inner_forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %func1 = transform.structured.match ops{["func.func"]} in %m1 + : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "scf-forall-to-parallel" to %func1 + : (!transform.any_op) -> !transform.any_op + %m2 = transform.apply_registered_pass "air-par-to-herd" to %m1 + : (!transform.any_op) -> !transform.any_op + %func2 = transform.structured.match ops{["func.func"]} in %m2 + : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "air-herd-vectorize" to %func2 + : (!transform.any_op) -> !transform.any_op - // Step 13: Canonicalization and CSE after tiling. - // Purpose: Cleans up IR, merges redundant ops, and prepares for further transforms. - %func_2 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func_2 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func_2 : !transform.any_op + %func3a = transform.structured.match ops{["func.func"]} in %m2 - //========================================================================== - // PHASE 6: PROMOTE INPUTS TO L1 AND TILE PROLOGUE/EPILOGUE - // Purpose: Move input data to L1, create tiled fill (prologue) and unpack (epilogue). - //========================================================================== + : (!transform.any_op) -> !transform.any_op - // Step 14: Promote input operands (A and B tiles) to L1 local memory. - // Purpose: Allocates L1 buffers for fast access during computation. - // memory_space = 2 corresponds to L1 (AIE local memory). - %buffer_a, %new_a = transform.structured.bufferize_to_allocation %fused_lhs_l1_pack2 - {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op - %buffer_b, %new_b = transform.structured.bufferize_to_allocation %fused_rhs_l1_pack2 - {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op + transform.apply_registered_pass "canonicalize" to %func3a - // Step 15: Create tiled prologue (fill operation). - // Purpose: Initializes output buffers in parallel across cores. - // Generalize fill to generic, interchange dimensions, then tile with forall. - %fill_op = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %generic_fill_op = transform.structured.generalize %fill_op - : (!transform.any_op) -> !transform.any_op - transform.annotate %generic_fill_op "init_fill" : !transform.any_op - %interchanged_fill_op = transform.structured.interchange %generic_fill_op - iterator_interchange = [1, 0, 2, 3] - : (!transform.any_op) -> !transform.any_op - %prologue_tiled_fill, %prologue_forall = - transform.structured.tile_using_forall %interchanged_fill_op tile_sizes [8, 8] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %prologue_forall "prologue_forall" : !transform.any_op + : (!transform.any_op) -> !transform.any_op - // Step 16: Create tiled epilogue (unpack operation). - // Purpose: Unpacks and writes results back to L2 in parallel across cores. - // Tile sizes [64, 64] match the L2 tile dimensions. - %unpack_op = transform.structured.match ops{["linalg.unpack"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %epilogue_tiled_unpack, %epilogue_forall = - transform.structured.tile_using_forall %unpack_op tile_sizes [64, 64] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %epilogue_forall "epilogue_forall" : !transform.any_op + %func3b = transform.structured.match ops{["func.func"]} in %m2 - // Step 17: Canonicalization and CSE after buffer promotion. - // Purpose: Merges redundant allocs/copies and simplifies the IR. - %func_3 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func_3 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func_3 : !transform.any_op + : (!transform.any_op) -> !transform.any_op - //========================================================================== - // PHASE 7: BUFFERIZATION AND AIR CLEANUP - // Purpose: Convert tensors to memrefs and optimize memory operations. - //========================================================================== + transform.apply_registered_pass "cse" to %func3b - // Step 18: One-shot bufferization of the function. - // Purpose: Converts all remaining tensors to memrefs for hardware execution. - %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_bufferized = transform.bufferization.one_shot_bufferize %func_op : (!transform.any_op) -> !transform.any_op + : (!transform.any_op) -> !transform.any_op - // Step 19: AIR-specific cleanup and memory optimization. - // Purpose: Removes uninitialized copies and eliminates redundant cascade memcpy patterns. - %func6 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func6 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func6 : !transform.any_op - transform.apply_patterns to %func6 { - transform.apply_patterns.canonicalization - } : !transform.any_op - %func_op_updated = transform.air.remove_uninitialized_copy %func6 : (!transform.any_op) -> !transform.any_op - %func_op_updated_1 = transform.air.eliminate_cascade_memcpy %func_op_updated : (!transform.any_op) -> !transform.any_op + %func3c = transform.structured.match ops{["func.func"]} in %m2 - //========================================================================== - // PHASE 8: FUSE LOOPS FOR L2 PINGPONG BUFFERING - // Purpose: Fuse L3->L2 copy loops with main compute loop for double buffering. - //========================================================================== + : (!transform.any_op) -> !transform.any_op - // Step 20: Fuse L3->L2 copy loops with the main K-reduction loop. - // Purpose: Expose L2 pingpong buffering opportunity by interleaving L3->L2 data transfer with L2->L1. - // Use annotation-based matching instead of fragile split_handle. - %for_loop_copy_1 = transform.structured.match ops{["scf.for"]} attributes{copy_a_loop} in %arg1 : (!transform.any_op) -> !transform.any_op - %for_loop_copy_2 = transform.structured.match ops{["scf.for"]} attributes{copy_b_loop} in %arg1 : (!transform.any_op) -> !transform.any_op - %main_for_loop = transform.structured.match ops{["scf.for"]} attributes{k_reduction_loop} in %arg1 : (!transform.any_op) -> !transform.any_op - %main_for_loop_norm = transform.air.normalize_for_bounds %main_for_loop : (!transform.any_op) -> !transform.any_op // Fold affine apply into for loop bound - transform.apply_cse to %func_op_updated_1 : !transform.any_op // Ensure loop bounds use shared cst ssa values - %fused_for_loop_2 = transform.loop.fuse_sibling %for_loop_copy_2 into %main_for_loop_norm - : (!transform.any_op, !transform.any_op) -> !transform.any_op - %fused_for_loop_1 = transform.loop.fuse_sibling %for_loop_copy_1 into %fused_for_loop_2 - : (!transform.any_op, !transform.any_op) -> !transform.any_op + transform.apply_registered_pass "fold-memref-alias-ops" to %func3c - //========================================================================== - // PHASE 9: TILE FOR VECTORIZATION - // Purpose: Final tiling to enable efficient vectorized execution on AIE vector units. - //========================================================================== + : (!transform.any_op) -> !transform.any_op - // Step 21: Tile linalg.generic (matmul) for vectorization. - // Purpose: Creates inner loops with sizes suitable for vector register usage. - // Tile sizes [2, 2, 1, 0, 0, 0] unroll M and N by 2 for register blocking. - // Use annotation-based matching instead of fragile split_handle. - %generic1 = transform.structured.match ops{["linalg.generic"]} attributes{init_fill} in %arg1 : (!transform.any_op) -> !transform.any_op - %generic2 = transform.structured.match ops{["linalg.generic"]} attributes{matmul_compute} in %arg1 : (!transform.any_op) -> !transform.any_op - %inner_most_generics, %vec_loops:3 = - transform.structured.tile_using_for %generic2 tile_sizes [2, 2, 1, 0, 0, 0] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + %m3 = transform.apply_registered_pass "air-matmul-codegen" with options = { + "vec-prep-cast1-target-element-type" = "f32", + "vec-prep-cast1-input-indices" = [2], + "vec-prep-cast1-output-indices" = [0] + } to %m2 : (!transform.any_op) -> !transform.any_op - // Step 22: Further tile and unroll innermost loops for full vectorization. - // Purpose: Completely unrolls the innermost M and N loops for register allocation. - %inner_most_matmul_to_unroll, %vec_loops_to_unroll:2 = - transform.structured.tile_using_for %inner_most_generics tile_sizes [1, 1, 0, 0, 0, 0] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - transform.loop.unroll %vec_loops_to_unroll#1 {factor = 2} : !transform.any_op - transform.loop.unroll %vec_loops_to_unroll#0 {factor = 2} : !transform.any_op + %func4a = transform.structured.match ops{["func.func"]} in %m3 - // Step 23: Tile linalg.generic (fill) for vectorized initialization. - // Purpose: Creates vector-sized tiles for efficient zero-initialization. - %inner_most_fills, %vec_fill_loops:2 = - transform.structured.tile_using_for %generic1 tile_sizes [1, 1] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + : (!transform.any_op) -> !transform.any_op - //========================================================================== - // PHASE 10: CONVERT TO AIE HERDS AND VECTORIZE - // Purpose: Map parallel loops to AIE cores (herds) and apply vectorization. - //========================================================================== + transform.apply_registered_pass "canonicalize" to %func4a - // Step 24: Convert scf.forall loops to AIE herd operations. - // Purpose: Maps parallel work to the 8x4 AIE core array. - // Each forall becomes an air.herd representing multi-core execution. - // Use annotation-based matching instead of fragile split_handle. - %forall1 = transform.structured.match ops{["scf.forall"]} attributes{prologue_forall} in %arg1 : (!transform.any_op) -> !transform.any_op - %forall2 = transform.structured.match ops{["scf.forall"]} attributes{compute_forall} in %arg1 : (!transform.any_op) -> !transform.any_op - %forall3 = transform.structured.match ops{["scf.forall"]} attributes{epilogue_forall} in %arg1 : (!transform.any_op) -> !transform.any_op - %parallel1 = transform.loop.forall_to_parallel %forall1 : (!transform.any_op) -> !transform.any_op - %herd1 = transform.air.par_to_herd %parallel1 : (!transform.any_op) -> !transform.any_op - transform.annotate %herd1 "prologue_herd" : !transform.any_op - %parallel2 = transform.loop.forall_to_parallel %forall2 : (!transform.any_op) -> !transform.any_op - %herd2 = transform.air.par_to_herd %parallel2 : (!transform.any_op) -> !transform.any_op - transform.annotate %herd2 "compute_herd" : !transform.any_op - %parallel3 = transform.loop.forall_to_parallel %forall3 : (!transform.any_op) -> !transform.any_op - %herd3 = transform.air.par_to_herd %parallel3 : (!transform.any_op) -> !transform.any_op - transform.annotate %herd3 "epilogue_herd" : !transform.any_op + : (!transform.any_op) -> !transform.any_op - // Step 25: Apply vectorization to AIE herds. - // Purpose: Converts scalar operations to vector operations for AIE vector units. - %herds = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %vectorized_herds = transform.air.herd_vectorize %herds : (!transform.any_op) -> !transform.any_op + %func4b = transform.structured.match ops{["func.func"]} in %m3 - // Step 26: Canonicalization after vectorization. - // Purpose: Simplifies vector operations and folds unit extent dimensions. - %func7 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func7 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - transform.apply_patterns.memref.fold_memref_alias_ops - } : !transform.any_op - %func_fold_1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_folded_1 = transform.air.fold_unit_extent_dims %func_fold_1 : (!transform.any_op) -> !transform.any_op + : (!transform.any_op) -> !transform.any_op - // Step 27: Eliminate redundant vector.transfer_read operations. - // Purpose: Removes duplicate memory reads for better performance. - %func7_rematch = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func1_optimized = transform.air.eliminate_redundant_vector_transfers %func7_rematch : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "cse" to %func4b - //========================================================================== - // PHASE 11: HOIST LOOP-INVARIANT VECTOR TRANSFERS - // Purpose: Move vector reads/writes out of innermost loops for register reuse. - //========================================================================== + : (!transform.any_op) -> !transform.any_op - // Step 28: Match the compute herd and prepare for hoisting optimization. - // Purpose: Identifies the compute herd and its vector operations for register optimization. - // Use annotation-based matching instead of fragile split_handle. - %herd2_1 = transform.structured.match ops{["air.herd"]} attributes{compute_herd} in %arg1 : (!transform.any_op) -> !transform.any_op + %func4c = transform.structured.match ops{["func.func"]} in %m3 - // Step 29: Identify the innermost loop for hoisting. - %scf_fors_1 = transform.structured.match ops{["scf.for"]} in %herd2_1 : (!transform.any_op) -> !transform.any_op - %innermost_for, %outer_fors = transform.split_handle %scf_fors_1 {overflow_result = 1} : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - - // Step 31: Cast vector types for correct accumulation precision. - // Purpose: Ensures vector.contract uses F32 for accumulation (BF16 inputs -> F32 output). - %vector_contracts = transform.structured.match ops{["vector.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %result11 = transform.air.vector_type_cast %vector_contracts {target_element_type = f32, input_indices = [2], output_indices = [0]} : (!transform.any_op) -> !transform.any_op - - // Step 32: Hoist all accumulator transfer pairs from innermost loop. - %innermost_for_updated_3 = transform.air.hoist_loop_invariant_transfers %herd2_1, %innermost_for : (!transform.any_op, !transform.any_op) -> !transform.any_op + : (!transform.any_op) -> !transform.any_op - // Step 33: Flatten loop iteration arguments and hoist vector transfer pointers. - // Purpose: Simplifies loop structure and moves pointer computations out of loops. - %innermost_for_updated_4 = transform.air.flatten_for_iter_args %innermost_for_updated_3 : (!transform.any_op) -> !transform.any_op - %innermost_for_updated_5 = transform.air.hoist_vector_transfer_pointers %innermost_for_updated_4 : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "fold-memref-alias-ops" to %func4c - // Step 34: Final canonicalization pass. - // Purpose: Cleans up the final IR for AIR/AIE lowering. - %func9 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func9 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - transform.apply_patterns.memref.fold_memref_alias_ops - } : !transform.any_op - %func_fold_2 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_folded_2 = transform.air.fold_unit_extent_dims %func_fold_2 : (!transform.any_op) -> !transform.any_op + : (!transform.any_op) -> !transform.any_op transform.yield } diff --git a/test/xrt/45_triton_matmul_ver4_strix_8x4/run.py b/test/xrt/45_triton_matmul_ver4_strix_8x4/run.py index 68099d80c..04a3f25e7 100644 --- a/test/xrt/45_triton_matmul_ver4_strix_8x4/run.py +++ b/test/xrt/45_triton_matmul_ver4_strix_8x4/run.py @@ -25,7 +25,7 @@ type=str, dest="transform_script", default="transform.mlir", - help="Transform script path", + help="Transform script path (legacy path).", ) parser.add_argument( "--output-format", @@ -93,7 +93,8 @@ pm = air.passmanager.PassManager.parse(pipeline) pm.run(air_module.operation) - # Load the MLIR transform IR from an external file + # Drive matmul codegen via the transform script. The script wraps the + # C++ air-matmul-codegen orchestrator via transform.apply_registered_pass. with open(args.transform_script, "r") as f: transform_ir_string = f.read() transform_ir = Module.parse(transform_ir_string) diff --git a/test/xrt/45_triton_matmul_ver4_strix_8x4/transform_aie2p.mlir b/test/xrt/45_triton_matmul_ver4_strix_8x4/transform_aie2p.mlir index 1551daad3..e5932fa7c 100644 --- a/test/xrt/45_triton_matmul_ver4_strix_8x4/transform_aie2p.mlir +++ b/test/xrt/45_triton_matmul_ver4_strix_8x4/transform_aie2p.mlir @@ -1,354 +1,102 @@ -// Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. // SPDX-License-Identifier: MIT - -//////////////////////////////////////////////////////////////////////////////// -// Transform Script for Matmul (Triton Ver4, Vectorized): Step-by-Step Annotated -// This script transforms a matmul IR into a tiled, packed, bufferized, and -// hardware-friendly form suitable for AIE execution. Each step is annotated -// with its purpose, assumptions, and relation to the IR. // -// Target configuration: 8x4 AIE core array (Strix) -// Data types: BF16 inputs, F32 accumulation -//////////////////////////////////////////////////////////////////////////////// +// Drives the C++ air-matmul-codegen orchestrator through the transform +// dialect. The matmul-specific tile/pack/bufferize/vectorize work is +// delegated to the orchestrator; the transform script keeps the +// non-matmul plumbing (scf.forall->herd, herd-vectorize, cleanup). module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - - //========================================================================== - // PHASE 1: TILE L3->L2 MEMORY COPIES - // Purpose: Tile the memref copy ops that move data from L3 (DDR) to L2 (shared memory). - //========================================================================== - - // Step 1: Convert memref.copy to linalg.copy and tile for L3->L2 data movement. - // Purpose: Transforms memref copies into tileable linalg operations for streaming data. - // Assumption: The IR contains memref.copy ops for A and B matrices. - %func10 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func10_updated = transform.air.convert_memref_copy_to_linalg_copy %func10 : (!transform.any_op) -> !transform.any_op - %copies = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %copy1, %copy2 = transform.split_handle %copies : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %tiled_copy1, %tile_copy_loop1 = - transform.structured.tile_using_for %copy1 tile_sizes [0, 64] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %tile_copy_loop1 "copy_a_loop" : !transform.any_op - %tiled_copy2, %tile_copy_loop2 = - transform.structured.tile_using_for %copy2 tile_sizes [64] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %tile_copy_loop2 "copy_b_loop" : !transform.any_op - - //========================================================================== - // PHASE 2: MATCH AND PREPARE CORE OPERATIONS - // Purpose: Identify fill and matmul operations, promote output to L2. - //========================================================================== - - // Step 2: Match the fill and matmul ops. - // Assumption: The IR contains linalg.fill and linalg.matmul ops representing - // initialization and main computation. - %fill = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - - // Step 3: Promote the result buffer (C matrix) to L2 shared memory. - // Purpose: Allocate output buffer in L2 for accumulation before writing back to L3. - // memory_space = 1 corresponds to L2 (shared memory). - %result_l2 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %result_l2_buffer, %result_t2_new = transform.structured.bufferize_to_allocation %result_l2 - {memory_space = 1, bufferize_destination_only, mempcy = "linalg.copy", emit_dealloc} : !transform.any_op - - //========================================================================== - // PHASE 3: PACK MATMUL FOR VECTORIZED COMPUTATION - // Purpose: Apply data tiling (packing) to enable efficient vectorized computation. - //========================================================================== - - // Step 4: Pack matmul with tile sizes [8, 8, 8]. - // Purpose: Transforms linalg.matmul into linalg.generic with packed layout. - // Assumption: Pack sizes [8, 8, 8] correspond to M, N, K tile dimensions for - // efficient AIE vector unit utilization. - %packed = transform.structured.pack %matmul packed_sizes = [8, 8, 8] - : (!transform.any_op) -> (!transform.any_op) - - // Step 5: Transpose A matrix for packed layout. - // Purpose: Ensures A operand has correct memory layout for vectorized access. - // Outer permutation [1, 0] swaps the outer tile dimensions. - %pack_producer_a = transform.get_producer_of_operand %packed[0] - : (!transform.any_op) -> (!transform.any_op) - %packed_a, %pack_a, %empty_unpack_a = - transform.structured.pack_transpose %pack_producer_a with_compute_op(%packed) - outer_perm = [1, 0] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - // Step 6: Transpose B matrix for packed layout. - // Purpose: Ensures B operand has correct memory layout for vectorized access. - // Both outer_perm and inner_perm [1, 0] transpose outer and inner tile dimensions. - %pack_producer_b = transform.get_producer_of_operand %packed_a[1] - : (!transform.any_op) -> (!transform.any_op) - %packed_b, %pack_b, %empty_unpack_b = - transform.structured.pack_transpose %pack_producer_b with_compute_op(%packed_a) - outer_perm = [1, 0] inner_perm = [1, 0] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - // Step 7: Transpose C matrix for packed layout. - // Purpose: Ensures C operand has correct memory layout matching A and B. - // Outer permutation [1, 0] aligns output tile dimensions. - %unpack = transform.get_consumers_of_result %packed_b[0] - : (!transform.any_op) -> (!transform.any_op) - %packed_c, %pack_c, %unpack_c = - transform.structured.pack_transpose %unpack with_compute_op(%packed_b) - outer_perm = [1, 0] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - // Step 8: Promote the output pack operation to L1 local memory. - // Purpose: Allocate L1 buffer for C matrix tiles during computation. - // memory_space = 2 corresponds to L1 (AIE local memory). - %output_l1_pack_op_source_buffer, %output_l1_pack_op_new = transform.structured.bufferize_to_allocation %pack_c - {memory_space = 2, bufferize_destination_only, memcpy_op = "linalg.copy", emit_dealloc} : !transform.any_op - - //========================================================================== - // PHASE 4: TILE REDUCTION AND FUSE PACK OPERATIONS - // Purpose: Tile the K dimension and fuse data movement into compute loops. - //========================================================================== - - // Step 9: Tile the reduction (K) dimension. - // Purpose: Enables streaming of A and B tiles along K dimension. - // Tile size [0, 0, 8] tiles only the K dimension with factor 8. - %tiled_reduction, %outer_for_loop = - transform.structured.tile_using_for %packed_c tile_sizes [0, 0, 8] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %outer_for_loop "k_reduction_loop" : !transform.any_op - - // Step 10: Fuse pack operations for A and B into the outer K-loop. - // Purpose: Moves data packing inside the loop for better locality and pipelining. - %fused_lhs_l1_pack, %2 = transform.structured.fuse_into_containing_op %pack_a into %outer_for_loop : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - %fused_rhs_l1_pack, %3 = transform.structured.fuse_into_containing_op %pack_b into %outer_for_loop : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - - //========================================================================== - // PHASE 5: TILE FOR MULTI-CORE PARALLELISM - // Purpose: Create parallel loops for mapping to 8x4 AIE core array. - //========================================================================== + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.consumed}) { - // Step 11: Tile matmul using scf.forall with tile size [8, 8, 0]. - // Purpose: Introduces parallelism across M and N dimensions for multi-core execution. - // Tile sizes [8, 8, 0] create 8x8 tiles for each AIE core to process. - %matmul_1 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %tiled_matmul_1, %inner_forall = - transform.structured.tile_using_forall %matmul_1 tile_sizes [8, 8, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %inner_forall "compute_forall" : !transform.any_op - transform.annotate %tiled_matmul_1 "matmul_compute" : !transform.any_op + // Phase 1: matmul codegen orchestrator (pre-vectorize half). + %m1 = transform.apply_registered_pass "air-matmul-codegen" with options = { + "bufferize-output-l2" = true, + "tile-l3-to-l2-copies" = true, "k-l2-tile" = 64, + "l2-pack-sizes" = [8, 8, 8], + "l2-lhs-outer-perm" = [1, 0], "l2-lhs-inner-perm" = [0, 1], + "l2-rhs-outer-perm" = [1, 0], "l2-rhs-inner-perm" = [1, 0], + "l2-acc-outer-perm" = [1, 0], "l2-acc-inner-perm" = [0, 1], + "outer-k-tile-factor" = 8, "outer-k-iter-index" = 2, + "core-tile" = [8, 8, 0], + "prologue-tile" = [8, 8], "epilogue-tile" = [64, 64], + "fill-iter-perm" = [1, 0, 2, 3], + "one-shot-bufferize" = true, + "post-bufferize-cleanup-first" = true, + "matmul-vec-tile" = [2, 2, 1, 0, 0, 0], + "matmul-unroll-vec-tile" = [1, 1, 0, 0, 0, 0], + "matmul-unroll-factor" = 2, + "fill-vec-tile" = [1, 1, 0, 0] + } to %arg1 : (!transform.any_op) -> !transform.any_op - // Step 12: Fuse pack operations into the inner parallel loop. - // Purpose: Ensures each core has its own data packing for independent execution. - %fused_lhs_l1_pack2, %6 = transform.structured.fuse_into_containing_op %fused_lhs_l1_pack into %inner_forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - %fused_rhs_l1_pack2, %7 = transform.structured.fuse_into_containing_op %fused_rhs_l1_pack into %inner_forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + // Phase 2: scf.forall -> scf.parallel -> air.herd, then vectorize herds. + %func1 = transform.structured.match ops{["func.func"]} in %m1 + : (!transform.any_op) -> !transform.any_op + %m2 = transform.apply_registered_pass "scf-forall-to-parallel" to %func1 + : (!transform.any_op) -> !transform.any_op + %m3 = transform.apply_registered_pass "air-par-to-herd" to %m1 + : (!transform.any_op) -> !transform.any_op + %func2 = transform.structured.match ops{["func.func"]} in %m3 + : (!transform.any_op) -> !transform.any_op + %m4 = transform.apply_registered_pass "air-herd-vectorize" to %func2 + : (!transform.any_op) -> !transform.any_op - // Step 13: Canonicalization and CSE after tiling. - // Purpose: Cleans up IR, merges redundant ops, and prepares for further transforms. - %func_2 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func_2 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func_2 : !transform.any_op + // Cleanup between vectorize and vec-prep. + %func3a = transform.structured.match ops{["func.func"]} in %m3 - //========================================================================== - // PHASE 6: PROMOTE INPUTS TO L1 AND TILE PROLOGUE/EPILOGUE - // Purpose: Move input data to L1, create tiled fill (prologue) and unpack (epilogue). - //========================================================================== + : (!transform.any_op) -> !transform.any_op - // Step 14: Promote input operands (A and B tiles) to L1 local memory. - // Purpose: Allocates L1 buffers for fast access during computation. - // memory_space = 2 corresponds to L1 (AIE local memory). - %buffer_a, %new_a = transform.structured.bufferize_to_allocation %fused_lhs_l1_pack2 - {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op - %buffer_b, %new_b = transform.structured.bufferize_to_allocation %fused_rhs_l1_pack2 - {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op + transform.apply_registered_pass "canonicalize" to %func3a - // Step 15: Create tiled prologue (fill operation). - // Purpose: Initializes output buffers in parallel across cores. - // Generalize fill to generic, interchange dimensions, then tile with forall. - %fill_op = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %generic_fill_op = transform.structured.generalize %fill_op - : (!transform.any_op) -> !transform.any_op - transform.annotate %generic_fill_op "init_fill" : !transform.any_op - %interchanged_fill_op = transform.structured.interchange %generic_fill_op - iterator_interchange = [1, 0, 2, 3] - : (!transform.any_op) -> !transform.any_op - %prologue_tiled_fill, %prologue_forall = - transform.structured.tile_using_forall %interchanged_fill_op tile_sizes [8, 8] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %prologue_forall "prologue_forall" : !transform.any_op + : (!transform.any_op) -> !transform.any_op - // Step 16: Create tiled epilogue (unpack operation). - // Purpose: Unpacks and writes results back to L2 in parallel across cores. - // Tile sizes [64, 64] match the L2 tile dimensions. - %unpack_op = transform.structured.match ops{["linalg.unpack"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %epilogue_tiled_unpack, %epilogue_forall = - transform.structured.tile_using_forall %unpack_op tile_sizes [64, 64] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %epilogue_forall "epilogue_forall" : !transform.any_op + %func3b = transform.structured.match ops{["func.func"]} in %m3 - // Step 17: Canonicalization and CSE after buffer promotion. - // Purpose: Merges redundant allocs/copies and simplifies the IR. - %func_3 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func_3 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func_3 : !transform.any_op + : (!transform.any_op) -> !transform.any_op - //========================================================================== - // PHASE 7: BUFFERIZATION AND AIR CLEANUP - // Purpose: Convert tensors to memrefs and optimize memory operations. - //========================================================================== + transform.apply_registered_pass "cse" to %func3b - // Step 18: One-shot bufferization of the function. - // Purpose: Converts all remaining tensors to memrefs for hardware execution. - %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_bufferized = transform.bufferization.one_shot_bufferize %func_op : (!transform.any_op) -> !transform.any_op + : (!transform.any_op) -> !transform.any_op - // Step 19: AIR-specific cleanup and memory optimization. - // Purpose: Removes uninitialized copies and eliminates redundant cascade memcpy patterns. - %func6 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func6 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func6 : !transform.any_op - transform.apply_patterns to %func6 { - transform.apply_patterns.canonicalization - } : !transform.any_op - %func_op_updated = transform.air.remove_uninitialized_copy %func6 : (!transform.any_op) -> !transform.any_op - %func_op_updated_1 = transform.air.eliminate_cascade_memcpy %func_op_updated : (!transform.any_op) -> !transform.any_op + %func3c = transform.structured.match ops{["func.func"]} in %m3 - //========================================================================== - // PHASE 8: FUSE LOOPS FOR L2 PINGPONG BUFFERING - // Purpose: Fuse L3->L2 copy loops with main compute loop for double buffering. - //========================================================================== + : (!transform.any_op) -> !transform.any_op - // Step 20: Fuse L3->L2 copy loops with the main K-reduction loop. - // Purpose: Expose L2 pingpong buffering opportunity by interleaving L3->L2 data transfer with L2->L1. - // Use annotation-based matching instead of fragile split_handle. - %for_loop_copy_1 = transform.structured.match ops{["scf.for"]} attributes{copy_a_loop} in %arg1 : (!transform.any_op) -> !transform.any_op - %for_loop_copy_2 = transform.structured.match ops{["scf.for"]} attributes{copy_b_loop} in %arg1 : (!transform.any_op) -> !transform.any_op - %main_for_loop = transform.structured.match ops{["scf.for"]} attributes{k_reduction_loop} in %arg1 : (!transform.any_op) -> !transform.any_op - %main_for_loop_norm = transform.air.normalize_for_bounds %main_for_loop : (!transform.any_op) -> !transform.any_op // Fold affine apply into for loop bound - transform.apply_cse to %func_op_updated_1 : !transform.any_op // Ensure loop bounds use shared cst ssa values - %fused_for_loop_2 = transform.loop.fuse_sibling %for_loop_copy_2 into %main_for_loop_norm - : (!transform.any_op, !transform.any_op) -> !transform.any_op - %fused_for_loop_1 = transform.loop.fuse_sibling %for_loop_copy_1 into %fused_for_loop_2 - : (!transform.any_op, !transform.any_op) -> !transform.any_op + transform.apply_registered_pass "fold-memref-alias-ops" to %func3c - //========================================================================== - // PHASE 9: TILE FOR VECTORIZATION - // Purpose: Final tiling to enable efficient vectorized execution on AIE vector units. - //========================================================================== + : (!transform.any_op) -> !transform.any_op - // Step 21: Tile linalg.generic (matmul) for vectorization. - // Purpose: Creates inner loops with sizes suitable for vector register usage. - // Tile sizes [2, 2, 1, 0, 0, 0] unroll M and N by 2 for register blocking. - // Use annotation-based matching instead of fragile split_handle. - %generic1 = transform.structured.match ops{["linalg.generic"]} attributes{init_fill} in %arg1 : (!transform.any_op) -> !transform.any_op - %generic2 = transform.structured.match ops{["linalg.generic"]} attributes{matmul_compute} in %arg1 : (!transform.any_op) -> !transform.any_op - %inner_most_generics, %vec_loops:3 = - transform.structured.tile_using_for %generic2 tile_sizes [2, 2, 1, 0, 0, 0] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + // Phase 3: matmul codegen orchestrator (vec-prep half). + %m5 = transform.apply_registered_pass "air-matmul-codegen" with options = { + "vec-prep-cast1-target-element-type" = "f32", + "vec-prep-cast1-input-indices" = [2], + "vec-prep-cast1-output-indices" = [0] + } to %m3 : (!transform.any_op) -> !transform.any_op - // Step 22: Further tile and unroll innermost loops for full vectorization. - // Purpose: Completely unrolls the innermost M and N loops for register allocation. - %inner_most_matmul_to_unroll, %vec_loops_to_unroll:2 = - transform.structured.tile_using_for %inner_most_generics tile_sizes [1, 1, 0, 0, 0, 0] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - transform.loop.unroll %vec_loops_to_unroll#1 {factor = 2} : !transform.any_op - transform.loop.unroll %vec_loops_to_unroll#0 {factor = 2} : !transform.any_op + // Final cleanup. + %func4a = transform.structured.match ops{["func.func"]} in %m5 - // Step 23: Tile linalg.generic (fill) for vectorized initialization. - // Purpose: Creates vector-sized tiles for efficient zero-initialization. - %inner_most_fills, %vec_fill_loops:2 = - transform.structured.tile_using_for %generic1 tile_sizes [1, 1] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + : (!transform.any_op) -> !transform.any_op - //========================================================================== - // PHASE 10: CONVERT TO AIE HERDS AND VECTORIZE - // Purpose: Map parallel loops to AIE cores (herds) and apply vectorization. - //========================================================================== + transform.apply_registered_pass "canonicalize" to %func4a - // Step 24: Convert scf.forall loops to AIE herd operations. - // Purpose: Maps parallel work to the 8x4 AIE core array. - // Each forall becomes an air.herd representing multi-core execution. - // Use annotation-based matching instead of fragile split_handle. - %forall1 = transform.structured.match ops{["scf.forall"]} attributes{prologue_forall} in %arg1 : (!transform.any_op) -> !transform.any_op - %forall2 = transform.structured.match ops{["scf.forall"]} attributes{compute_forall} in %arg1 : (!transform.any_op) -> !transform.any_op - %forall3 = transform.structured.match ops{["scf.forall"]} attributes{epilogue_forall} in %arg1 : (!transform.any_op) -> !transform.any_op - %parallel1 = transform.loop.forall_to_parallel %forall1 : (!transform.any_op) -> !transform.any_op - %herd1 = transform.air.par_to_herd %parallel1 : (!transform.any_op) -> !transform.any_op - transform.annotate %herd1 "prologue_herd" : !transform.any_op - %parallel2 = transform.loop.forall_to_parallel %forall2 : (!transform.any_op) -> !transform.any_op - %herd2 = transform.air.par_to_herd %parallel2 : (!transform.any_op) -> !transform.any_op - transform.annotate %herd2 "compute_herd" : !transform.any_op - %parallel3 = transform.loop.forall_to_parallel %forall3 : (!transform.any_op) -> !transform.any_op - %herd3 = transform.air.par_to_herd %parallel3 : (!transform.any_op) -> !transform.any_op - transform.annotate %herd3 "epilogue_herd" : !transform.any_op + : (!transform.any_op) -> !transform.any_op - // Step 25: Apply vectorization to AIE herds. - // Purpose: Converts scalar operations to vector operations for AIE vector units. - %herds = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %vectorized_herds = transform.air.herd_vectorize %herds : (!transform.any_op) -> !transform.any_op + %func4b = transform.structured.match ops{["func.func"]} in %m5 - // Step 26: Canonicalization after vectorization. - // Purpose: Simplifies vector operations and folds unit extent dimensions. - %func7 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func7 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - transform.apply_patterns.memref.fold_memref_alias_ops - } : !transform.any_op - %func_fold_1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_folded_1 = transform.air.fold_unit_extent_dims %func_fold_1 : (!transform.any_op) -> !transform.any_op + : (!transform.any_op) -> !transform.any_op - // Step 27: Eliminate redundant vector.transfer_read operations. - // Purpose: Removes duplicate memory reads for better performance. - %func7_rematch = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func1_optimized = transform.air.eliminate_redundant_vector_transfers %func7_rematch : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "cse" to %func4b - //========================================================================== - // PHASE 11: HOIST LOOP-INVARIANT VECTOR TRANSFERS - // Purpose: Move vector reads/writes out of innermost loops for register reuse. - //========================================================================== + : (!transform.any_op) -> !transform.any_op - // Step 28: Match the compute herd and prepare for hoisting optimization. - // Purpose: Identifies the compute herd and its vector operations for register optimization. - // Use annotation-based matching instead of fragile split_handle. - %herd2_1 = transform.structured.match ops{["air.herd"]} attributes{compute_herd} in %arg1 : (!transform.any_op) -> !transform.any_op + %func4c = transform.structured.match ops{["func.func"]} in %m5 - // Step 29: Identify the innermost loop for hoisting. - %scf_fors_1 = transform.structured.match ops{["scf.for"]} in %herd2_1 : (!transform.any_op) -> !transform.any_op - %innermost_for, %outer_fors = transform.split_handle %scf_fors_1 {overflow_result = 1} : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - - // Step 31: Cast vector types for correct accumulation precision. - // Purpose: Ensures vector.contract uses F32 for accumulation (BF16 inputs -> F32 output). - %vector_contracts = transform.structured.match ops{["vector.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %result11 = transform.air.vector_type_cast %vector_contracts {target_element_type = f32, input_indices = [2], output_indices = [0]} : (!transform.any_op) -> !transform.any_op - - // Step 32: Hoist all accumulator transfer pairs from innermost loop. - %innermost_for_updated_3 = transform.air.hoist_loop_invariant_transfers %herd2_1, %innermost_for : (!transform.any_op, !transform.any_op) -> !transform.any_op + : (!transform.any_op) -> !transform.any_op - // Step 33: Flatten loop iteration arguments and hoist vector transfer pointers. - // Purpose: Simplifies loop structure and moves pointer computations out of loops. - %innermost_for_updated_4 = transform.air.flatten_for_iter_args %innermost_for_updated_3 : (!transform.any_op) -> !transform.any_op - %innermost_for_updated_5 = transform.air.hoist_vector_transfer_pointers %innermost_for_updated_4 : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "fold-memref-alias-ops" to %func4c - // Step 34: Final canonicalization pass. - // Purpose: Cleans up the final IR for AIR/AIE lowering. - %func9 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func9 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - transform.apply_patterns.memref.fold_memref_alias_ops - } : !transform.any_op - %func_fold_2 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_folded_2 = transform.air.fold_unit_extent_dims %func_fold_2 : (!transform.any_op) -> !transform.any_op + : (!transform.any_op) -> !transform.any_op transform.yield } diff --git a/test/xrt/46_triton_matmul_ver4_strix_8x4_i8_i8_i32/run.py b/test/xrt/46_triton_matmul_ver4_strix_8x4_i8_i8_i32/run.py index 83c7cdf03..7007e324a 100644 --- a/test/xrt/46_triton_matmul_ver4_strix_8x4_i8_i8_i32/run.py +++ b/test/xrt/46_triton_matmul_ver4_strix_8x4_i8_i8_i32/run.py @@ -32,7 +32,7 @@ type=str, dest="transform_script", default="transform.mlir", - help="Transform script path", + help="Transform script path (legacy path).", ) parser.add_argument( "--compile-only", @@ -85,7 +85,8 @@ pm = air.passmanager.PassManager.parse(pipeline) pm.run(air_module.operation) - # Load the MLIR transform IR from an external file + # Drive matmul codegen via the transform script (delegates to the C++ + # air-matmul-codegen orchestrator via transform.apply_registered_pass). with open(args.transform_script, "r") as f: transform_ir_string = f.read() transform_ir = Module.parse(transform_ir_string) diff --git a/test/xrt/46_triton_matmul_ver4_strix_8x4_i8_i8_i32/transform_aie2p.mlir b/test/xrt/46_triton_matmul_ver4_strix_8x4_i8_i8_i32/transform_aie2p.mlir index 593df461b..2a2511d60 100644 --- a/test/xrt/46_triton_matmul_ver4_strix_8x4_i8_i8_i32/transform_aie2p.mlir +++ b/test/xrt/46_triton_matmul_ver4_strix_8x4_i8_i8_i32/transform_aie2p.mlir @@ -1,354 +1,95 @@ // Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. // SPDX-License-Identifier: MIT - -//////////////////////////////////////////////////////////////////////////////// -// Transform Script for Matmul (Triton Ver4, Vectorized): Step-by-Step Annotated -// This script transforms a matmul IR into a tiled, packed, bufferized, and -// hardware-friendly form suitable for AIE execution. Each step is annotated -// with its purpose, assumptions, and relation to the IR. // -// Target configuration: 8x4 AIE core array (Strix) -// Data types: INT8 inputs, INT32 accumulation -//////////////////////////////////////////////////////////////////////////////// +// AIE2P (Strix) single-pack i8/i8/i32 matmul codegen via the C++ +// air-matmul-codegen orchestrator. mmul=8x8x8, i32 accumulation. module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - - //========================================================================== - // PHASE 1: TILE L3->L2 MEMORY COPIES - // Purpose: Tile the memref copy ops that move data from L3 (DDR) to L2 (shared memory). - //========================================================================== - - // Step 1: Convert memref.copy to linalg.copy and tile for L3->L2 data movement. - // Purpose: Transforms memref copies into tileable linalg operations for streaming data. - // Assumption: The IR contains memref.copy ops for A and B matrices. - %func10 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func10_updated = transform.air.convert_memref_copy_to_linalg_copy %func10 : (!transform.any_op) -> !transform.any_op - %copies = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %copy1, %copy2 = transform.split_handle %copies : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %tiled_copy1, %tile_copy_loop1 = - transform.structured.tile_using_for %copy1 tile_sizes [0, 64] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %tile_copy_loop1 "copy_a_loop" : !transform.any_op - %tiled_copy2, %tile_copy_loop2 = - transform.structured.tile_using_for %copy2 tile_sizes [64] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %tile_copy_loop2 "copy_b_loop" : !transform.any_op - - //========================================================================== - // PHASE 2: MATCH AND PREPARE CORE OPERATIONS - // Purpose: Identify fill and matmul operations, promote output to L2. - //========================================================================== - - // Step 2: Match the fill and matmul ops. - // Assumption: The IR contains linalg.fill and linalg.matmul ops representing - // initialization and main computation. - %fill = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - - // Step 3: Promote the result buffer (C matrix) to L2 shared memory. - // Purpose: Allocate output buffer in L2 for accumulation before writing back to L3. - // memory_space = 1 corresponds to L2 (shared memory). - %result_l2 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %result_l2_buffer, %result_t2_new = transform.structured.bufferize_to_allocation %result_l2 - {memory_space = 1, bufferize_destination_only, mempcy = "linalg.copy", emit_dealloc} : !transform.any_op - - //========================================================================== - // PHASE 3: PACK MATMUL FOR VECTORIZED COMPUTATION - // Purpose: Apply data tiling (packing) to enable efficient vectorized computation. - //========================================================================== - - // Step 4: Pack matmul with tile sizes [8, 8, 8]. - // Purpose: Transforms linalg.matmul into linalg.generic with packed layout. - // Assumption: Pack sizes [8, 8, 8] correspond to M, N, K tile dimensions for - // efficient AIE vector unit utilization. - %packed = transform.structured.pack %matmul packed_sizes = [8, 8, 8] - : (!transform.any_op) -> (!transform.any_op) - - // Step 5: Transpose A matrix for packed layout. - // Purpose: Ensures A operand has correct memory layout for vectorized access. - // Outer permutation [1, 0] swaps the outer tile dimensions. - %pack_producer_a = transform.get_producer_of_operand %packed[0] - : (!transform.any_op) -> (!transform.any_op) - %packed_a, %pack_a, %empty_unpack_a = - transform.structured.pack_transpose %pack_producer_a with_compute_op(%packed) - outer_perm = [1, 0] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - // Step 6: Transpose B matrix for packed layout. - // Purpose: Ensures B operand has correct memory layout for vectorized access. - // Both outer_perm and inner_perm [1, 0] transpose outer and inner tile dimensions. - %pack_producer_b = transform.get_producer_of_operand %packed_a[1] - : (!transform.any_op) -> (!transform.any_op) - %packed_b, %pack_b, %empty_unpack_b = - transform.structured.pack_transpose %pack_producer_b with_compute_op(%packed_a) - outer_perm = [1, 0] inner_perm = [1, 0] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - // Step 7: Transpose C matrix for packed layout. - // Purpose: Ensures C operand has correct memory layout matching A and B. - // Outer permutation [1, 0] aligns output tile dimensions. - %unpack = transform.get_consumers_of_result %packed_b[0] - : (!transform.any_op) -> (!transform.any_op) - %packed_c, %pack_c, %unpack_c = - transform.structured.pack_transpose %unpack with_compute_op(%packed_b) - outer_perm = [1, 0] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - // Step 8: Promote the output pack operation to L1 local memory. - // Purpose: Allocate L1 buffer for C matrix tiles during computation. - // memory_space = 2 corresponds to L1 (AIE local memory). - %output_l1_pack_op_source_buffer, %output_l1_pack_op_new = transform.structured.bufferize_to_allocation %pack_c - {memory_space = 2, bufferize_destination_only, memcpy_op = "linalg.copy", emit_dealloc} : !transform.any_op - - //========================================================================== - // PHASE 4: TILE REDUCTION AND FUSE PACK OPERATIONS - // Purpose: Tile the K dimension and fuse data movement into compute loops. - //========================================================================== - - // Step 9: Tile the reduction (K) dimension. - // Purpose: Enables streaming of A and B tiles along K dimension. - // Tile size [0, 0, 8] tiles only the K dimension with factor 8. - %tiled_reduction, %outer_for_loop = - transform.structured.tile_using_for %packed_c tile_sizes [0, 0, 8] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %outer_for_loop "k_reduction_loop" : !transform.any_op - - // Step 10: Fuse pack operations for A and B into the outer K-loop. - // Purpose: Moves data packing inside the loop for better locality and pipelining. - %fused_lhs_l1_pack, %2 = transform.structured.fuse_into_containing_op %pack_a into %outer_for_loop : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - %fused_rhs_l1_pack, %3 = transform.structured.fuse_into_containing_op %pack_b into %outer_for_loop : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - - //========================================================================== - // PHASE 5: TILE FOR MULTI-CORE PARALLELISM - // Purpose: Create parallel loops for mapping to 8x4 AIE core array. - //========================================================================== + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.consumed}) { - // Step 11: Tile matmul using scf.forall with tile size [8, 8, 0]. - // Purpose: Introduces parallelism across M and N dimensions for multi-core execution. - // Tile sizes [8, 8, 0] create 8x8 tiles for each AIE core to process. - %matmul_1 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %tiled_matmul_1, %inner_forall = - transform.structured.tile_using_forall %matmul_1 tile_sizes [8, 8, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %inner_forall "compute_forall" : !transform.any_op - transform.annotate %tiled_matmul_1 "matmul_compute" : !transform.any_op + %m1 = transform.apply_registered_pass "air-matmul-codegen" with options = { + "bufferize-output-l2" = true, + "tile-l3-to-l2-copies" = true, "k-l2-tile" = 64, + "l2-pack-sizes" = [8, 8, 8], + "l2-lhs-outer-perm" = [1, 0], "l2-lhs-inner-perm" = [0, 1], + "l2-rhs-outer-perm" = [1, 0], "l2-rhs-inner-perm" = [1, 0], + "l2-acc-outer-perm" = [1, 0], "l2-acc-inner-perm" = [0, 1], + "outer-k-tile-factor" = 8, "outer-k-iter-index" = 2, + "core-tile" = [8, 8, 0], + "prologue-tile" = [8, 8], "epilogue-tile" = [64, 64], + "fill-iter-perm" = [1, 0, 2, 3], + "one-shot-bufferize" = true, + "post-bufferize-cleanup-first" = true, + "matmul-vec-tile" = [2, 2, 1, 0, 0, 0], + "matmul-unroll-vec-tile" = [1, 1, 0, 0, 0, 0], + "matmul-unroll-factor" = 2, + "fill-vec-tile" = [1, 1, 0, 0] + } to %arg1 : (!transform.any_op) -> !transform.any_op - // Step 12: Fuse pack operations into the inner parallel loop. - // Purpose: Ensures each core has its own data packing for independent execution. - %fused_lhs_l1_pack2, %6 = transform.structured.fuse_into_containing_op %fused_lhs_l1_pack into %inner_forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - %fused_rhs_l1_pack2, %7 = transform.structured.fuse_into_containing_op %fused_rhs_l1_pack into %inner_forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %func1 = transform.structured.match ops{["func.func"]} in %m1 + : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "scf-forall-to-parallel" to %func1 + : (!transform.any_op) -> !transform.any_op + %m2 = transform.apply_registered_pass "air-par-to-herd" to %m1 + : (!transform.any_op) -> !transform.any_op + %func2 = transform.structured.match ops{["func.func"]} in %m2 + : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "air-herd-vectorize" to %func2 + : (!transform.any_op) -> !transform.any_op - // Step 13: Canonicalization and CSE after tiling. - // Purpose: Cleans up IR, merges redundant ops, and prepares for further transforms. - %func_2 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func_2 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func_2 : !transform.any_op + %func3a = transform.structured.match ops{["func.func"]} in %m2 - //========================================================================== - // PHASE 6: PROMOTE INPUTS TO L1 AND TILE PROLOGUE/EPILOGUE - // Purpose: Move input data to L1, create tiled fill (prologue) and unpack (epilogue). - //========================================================================== + : (!transform.any_op) -> !transform.any_op - // Step 14: Promote input operands (A and B tiles) to L1 local memory. - // Purpose: Allocates L1 buffers for fast access during computation. - // memory_space = 2 corresponds to L1 (AIE local memory). - %buffer_a, %new_a = transform.structured.bufferize_to_allocation %fused_lhs_l1_pack2 - {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op - %buffer_b, %new_b = transform.structured.bufferize_to_allocation %fused_rhs_l1_pack2 - {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op + transform.apply_registered_pass "canonicalize" to %func3a - // Step 15: Create tiled prologue (fill operation). - // Purpose: Initializes output buffers in parallel across cores. - // Generalize fill to generic, interchange dimensions, then tile with forall. - %fill_op = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %generic_fill_op = transform.structured.generalize %fill_op - : (!transform.any_op) -> !transform.any_op - transform.annotate %generic_fill_op "init_fill" : !transform.any_op - %interchanged_fill_op = transform.structured.interchange %generic_fill_op - iterator_interchange = [1, 0, 2, 3] - : (!transform.any_op) -> !transform.any_op - %prologue_tiled_fill, %prologue_forall = - transform.structured.tile_using_forall %interchanged_fill_op tile_sizes [8, 8] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %prologue_forall "prologue_forall" : !transform.any_op + : (!transform.any_op) -> !transform.any_op - // Step 16: Create tiled epilogue (unpack operation). - // Purpose: Unpacks and writes results back to L2 in parallel across cores. - // Tile sizes [64, 64] match the L2 tile dimensions. - %unpack_op = transform.structured.match ops{["linalg.unpack"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %epilogue_tiled_unpack, %epilogue_forall = - transform.structured.tile_using_forall %unpack_op tile_sizes [64, 64] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %epilogue_forall "epilogue_forall" : !transform.any_op + %func3b = transform.structured.match ops{["func.func"]} in %m2 - // Step 17: Canonicalization and CSE after buffer promotion. - // Purpose: Merges redundant allocs/copies and simplifies the IR. - %func_3 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func_3 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func_3 : !transform.any_op + : (!transform.any_op) -> !transform.any_op - //========================================================================== - // PHASE 7: BUFFERIZATION AND AIR CLEANUP - // Purpose: Convert tensors to memrefs and optimize memory operations. - //========================================================================== + transform.apply_registered_pass "cse" to %func3b - // Step 18: One-shot bufferization of the function. - // Purpose: Converts all remaining tensors to memrefs for hardware execution. - %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_bufferized = transform.bufferization.one_shot_bufferize %func_op : (!transform.any_op) -> !transform.any_op + : (!transform.any_op) -> !transform.any_op - // Step 19: AIR-specific cleanup and memory optimization. - // Purpose: Removes uninitialized copies and eliminates redundant cascade memcpy patterns. - %func6 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func6 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func6 : !transform.any_op - transform.apply_patterns to %func6 { - transform.apply_patterns.canonicalization - } : !transform.any_op - %func_op_updated = transform.air.remove_uninitialized_copy %func6 : (!transform.any_op) -> !transform.any_op - %func_op_updated_1 = transform.air.eliminate_cascade_memcpy %func_op_updated : (!transform.any_op) -> !transform.any_op + %func3c = transform.structured.match ops{["func.func"]} in %m2 - //========================================================================== - // PHASE 8: FUSE LOOPS FOR L2 PINGPONG BUFFERING - // Purpose: Fuse L3->L2 copy loops with main compute loop for double buffering. - //========================================================================== + : (!transform.any_op) -> !transform.any_op - // Step 20: Fuse L3->L2 copy loops with the main K-reduction loop. - // Purpose: Expose L2 pingpong buffering opportunity by interleaving L3->L2 data transfer with L2->L1. - // Use annotation-based matching instead of fragile split_handle. - %for_loop_copy_1 = transform.structured.match ops{["scf.for"]} attributes{copy_a_loop} in %arg1 : (!transform.any_op) -> !transform.any_op - %for_loop_copy_2 = transform.structured.match ops{["scf.for"]} attributes{copy_b_loop} in %arg1 : (!transform.any_op) -> !transform.any_op - %main_for_loop = transform.structured.match ops{["scf.for"]} attributes{k_reduction_loop} in %arg1 : (!transform.any_op) -> !transform.any_op - %main_for_loop_norm = transform.air.normalize_for_bounds %main_for_loop : (!transform.any_op) -> !transform.any_op // Fold affine apply into for loop bound - transform.apply_cse to %func_op_updated_1 : !transform.any_op // Ensure loop bounds use shared cst ssa values - %fused_for_loop_2 = transform.loop.fuse_sibling %for_loop_copy_2 into %main_for_loop_norm - : (!transform.any_op, !transform.any_op) -> !transform.any_op - %fused_for_loop_1 = transform.loop.fuse_sibling %for_loop_copy_1 into %fused_for_loop_2 - : (!transform.any_op, !transform.any_op) -> !transform.any_op + transform.apply_registered_pass "fold-memref-alias-ops" to %func3c - //========================================================================== - // PHASE 9: TILE FOR VECTORIZATION - // Purpose: Final tiling to enable efficient vectorized execution on AIE vector units. - //========================================================================== + : (!transform.any_op) -> !transform.any_op - // Step 21: Tile linalg.generic (matmul) for vectorization. - // Purpose: Creates inner loops with sizes suitable for vector register usage. - // Tile sizes [2, 2, 1, 0, 0, 0] unroll M and N by 2 for register blocking. - // Use annotation-based matching instead of fragile split_handle. - %generic1 = transform.structured.match ops{["linalg.generic"]} attributes{init_fill} in %arg1 : (!transform.any_op) -> !transform.any_op - %generic2 = transform.structured.match ops{["linalg.generic"]} attributes{matmul_compute} in %arg1 : (!transform.any_op) -> !transform.any_op - %inner_most_generics, %vec_loops:3 = - transform.structured.tile_using_for %generic2 tile_sizes [2, 2, 1, 0, 0, 0] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + %m3 = transform.apply_registered_pass "air-matmul-codegen" with options = { + "vec-prep-cast1-target-element-type" = "i32", + "vec-prep-cast1-input-indices" = [2], + "vec-prep-cast1-output-indices" = [0] + } to %m2 : (!transform.any_op) -> !transform.any_op - // Step 22: Further tile and unroll innermost loops for full vectorization. - // Purpose: Completely unrolls the innermost M and N loops for register allocation. - %inner_most_matmul_to_unroll, %vec_loops_to_unroll:2 = - transform.structured.tile_using_for %inner_most_generics tile_sizes [1, 1, 0, 0, 0, 0] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - transform.loop.unroll %vec_loops_to_unroll#1 {factor = 2} : !transform.any_op - transform.loop.unroll %vec_loops_to_unroll#0 {factor = 2} : !transform.any_op + %func4a = transform.structured.match ops{["func.func"]} in %m3 - // Step 23: Tile linalg.generic (fill) for vectorized initialization. - // Purpose: Creates vector-sized tiles for efficient zero-initialization. - %inner_most_fills, %vec_fill_loops:2 = - transform.structured.tile_using_for %generic1 tile_sizes [1, 1] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + : (!transform.any_op) -> !transform.any_op - //========================================================================== - // PHASE 10: CONVERT TO AIE HERDS AND VECTORIZE - // Purpose: Map parallel loops to AIE cores (herds) and apply vectorization. - //========================================================================== + transform.apply_registered_pass "canonicalize" to %func4a - // Step 24: Convert scf.forall loops to AIE herd operations. - // Purpose: Maps parallel work to the 8x4 AIE core array. - // Each forall becomes an air.herd representing multi-core execution. - // Use annotation-based matching instead of fragile split_handle. - %forall1 = transform.structured.match ops{["scf.forall"]} attributes{prologue_forall} in %arg1 : (!transform.any_op) -> !transform.any_op - %forall2 = transform.structured.match ops{["scf.forall"]} attributes{compute_forall} in %arg1 : (!transform.any_op) -> !transform.any_op - %forall3 = transform.structured.match ops{["scf.forall"]} attributes{epilogue_forall} in %arg1 : (!transform.any_op) -> !transform.any_op - %parallel1 = transform.loop.forall_to_parallel %forall1 : (!transform.any_op) -> !transform.any_op - %herd1 = transform.air.par_to_herd %parallel1 : (!transform.any_op) -> !transform.any_op - transform.annotate %herd1 "prologue_herd" : !transform.any_op - %parallel2 = transform.loop.forall_to_parallel %forall2 : (!transform.any_op) -> !transform.any_op - %herd2 = transform.air.par_to_herd %parallel2 : (!transform.any_op) -> !transform.any_op - transform.annotate %herd2 "compute_herd" : !transform.any_op - %parallel3 = transform.loop.forall_to_parallel %forall3 : (!transform.any_op) -> !transform.any_op - %herd3 = transform.air.par_to_herd %parallel3 : (!transform.any_op) -> !transform.any_op - transform.annotate %herd3 "epilogue_herd" : !transform.any_op + : (!transform.any_op) -> !transform.any_op - // Step 25: Apply vectorization to AIE herds. - // Purpose: Converts scalar operations to vector operations for AIE vector units. - %herds = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %vectorized_herds = transform.air.herd_vectorize %herds : (!transform.any_op) -> !transform.any_op + %func4b = transform.structured.match ops{["func.func"]} in %m3 - // Step 26: Canonicalization after vectorization. - // Purpose: Simplifies vector operations and folds unit extent dimensions. - %func7 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func7 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - transform.apply_patterns.memref.fold_memref_alias_ops - } : !transform.any_op - %func_fold_1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_folded_1 = transform.air.fold_unit_extent_dims %func_fold_1 : (!transform.any_op) -> !transform.any_op + : (!transform.any_op) -> !transform.any_op - // Step 27: Eliminate redundant vector.transfer_read operations. - // Purpose: Removes duplicate memory reads for better performance. - %func7_rematch = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func1_optimized = transform.air.eliminate_redundant_vector_transfers %func7_rematch : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "cse" to %func4b - //========================================================================== - // PHASE 11: HOIST LOOP-INVARIANT VECTOR TRANSFERS - // Purpose: Move vector reads/writes out of innermost loops for register reuse. - //========================================================================== + : (!transform.any_op) -> !transform.any_op - // Step 28: Match the compute herd and prepare for hoisting optimization. - // Purpose: Identifies the compute herd and its vector operations for register optimization. - // Use annotation-based matching instead of fragile split_handle. - %herd2_1 = transform.structured.match ops{["air.herd"]} attributes{compute_herd} in %arg1 : (!transform.any_op) -> !transform.any_op + %func4c = transform.structured.match ops{["func.func"]} in %m3 - // Step 29: Identify the innermost loop for hoisting. - %scf_fors_1 = transform.structured.match ops{["scf.for"]} in %herd2_1 : (!transform.any_op) -> !transform.any_op - %innermost_for, %outer_fors = transform.split_handle %scf_fors_1 {overflow_result = 1} : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - - // Step 31: Cast vector types for correct accumulation precision. - // Purpose: Ensures vector.contract uses INT32 for accumulation (INT8 inputs -> INT32 output). - %vector_contracts = transform.structured.match ops{["vector.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %result11 = transform.air.vector_type_cast %vector_contracts {target_element_type = i32, input_indices = [2], output_indices = [0]} : (!transform.any_op) -> !transform.any_op - - // Step 32: Hoist all accumulator transfer pairs from innermost loop. - %innermost_for_updated_3 = transform.air.hoist_loop_invariant_transfers %herd2_1, %innermost_for : (!transform.any_op, !transform.any_op) -> !transform.any_op + : (!transform.any_op) -> !transform.any_op - // Step 33: Flatten loop iteration arguments and hoist vector transfer pointers. - // Purpose: Simplifies loop structure and moves pointer computations out of loops. - %innermost_for_updated_4 = transform.air.flatten_for_iter_args %innermost_for_updated_3 : (!transform.any_op) -> !transform.any_op - %innermost_for_updated_5 = transform.air.hoist_vector_transfer_pointers %innermost_for_updated_4 : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "fold-memref-alias-ops" to %func4c - // Step 34: Final canonicalization pass. - // Purpose: Cleans up the final IR for AIR/AIE lowering. - %func9 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func9 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - transform.apply_patterns.memref.fold_memref_alias_ops - } : !transform.any_op - %func_fold_2 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_folded_2 = transform.air.fold_unit_extent_dims %func_fold_2 : (!transform.any_op) -> !transform.any_op + : (!transform.any_op) -> !transform.any_op transform.yield } diff --git a/test/xrt/48_triton_matmul_ver4_strix_4x4_bf16_output/run.py b/test/xrt/48_triton_matmul_ver4_strix_4x4_bf16_output/run.py index b6904605a..fa854fe1c 100644 --- a/test/xrt/48_triton_matmul_ver4_strix_4x4_bf16_output/run.py +++ b/test/xrt/48_triton_matmul_ver4_strix_4x4_bf16_output/run.py @@ -89,7 +89,8 @@ pm = air.passmanager.PassManager.parse(pipeline) pm.run(air_module.operation) - # Load the MLIR transform IR from an external file + # Drive matmul codegen via the transform script (delegates to the C++ + # air-matmul-codegen orchestrator via transform.apply_registered_pass). with open(args.transform_script, "r") as f: transform_ir_string = f.read() transform_ir = Module.parse(transform_ir_string) @@ -119,6 +120,12 @@ pm = air.passmanager.PassManager.parse(pipeline) pm.run(air_module.operation) + import os + + if os.environ.get("AIR_DUMP_FINAL_IR"): + with open(os.environ["AIR_DUMP_FINAL_IR"], "w") as f: + f.write(str(air_module)) + ############################################### # Run compile and load ############################################### diff --git a/test/xrt/48_triton_matmul_ver4_strix_4x4_bf16_output/transform_aie2p.mlir b/test/xrt/48_triton_matmul_ver4_strix_4x4_bf16_output/transform_aie2p.mlir index cb0b1d613..fcb6aa480 100644 --- a/test/xrt/48_triton_matmul_ver4_strix_4x4_bf16_output/transform_aie2p.mlir +++ b/test/xrt/48_triton_matmul_ver4_strix_4x4_bf16_output/transform_aie2p.mlir @@ -1,369 +1,96 @@ // Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. // SPDX-License-Identifier: MIT - -//////////////////////////////////////////////////////////////////////////////// -// Transform Script for Matmul with BF16 Output (Triton Ver4, Vectorized) -// -// This script transforms a matmul IR into a tiled, packed, bufferized, and -// hardware-friendly form suitable for AIE execution. -// -// Target configuration: 8x4 AIE core array (Strix) -// Data types: BF16 inputs, F32 accumulation, BF16 output // -// Memory Hierarchy: -// L3 (DDR) -> L2 (Shared Memory, memory_space=1) -> L1 (AIE Local, memory_space=2) -//////////////////////////////////////////////////////////////////////////////// +// AIE2P (Strix) single-pack bf16-out matmul codegen via the C++ +// air-matmul-codegen orchestrator. mmul=8x8x8, 256x256x256 launch. module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - - //========================================================================== - // PHASE 1: TILE L3->L2 MEMORY COPIES - // Convert memref.copy to linalg.copy and tile for streaming data movement. - //========================================================================== - - // Step 1: Convert memref.copy ops to linalg.copy and tile them. - // This transforms the A and B matrix copies from L3 to L2 into tileable loops. - %func10 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func10_updated = transform.air.convert_memref_copy_to_linalg_copy %func10 : (!transform.any_op) -> !transform.any_op - %copies = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %copy1, %copy2 = transform.split_handle %copies : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %tiled_copy1, %tile_copy_loop1 = - transform.structured.tile_using_for %copy1 tile_sizes [0, 64] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %tile_copy_loop1 "copy_a_loop" : !transform.any_op - %tiled_copy2, %tile_copy_loop2 = - transform.structured.tile_using_for %copy2 tile_sizes [64] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %tile_copy_loop2 "copy_b_loop" : !transform.any_op - - //========================================================================== - // PHASE 2: FUSE TRUNCF AND PREPARE MATMUL - // Fuse the output truncation into matmul and promote output buffer to L2. - //========================================================================== - - // Step 2: Match the fill and matmul operations. - %fill = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - - // Step 3: Fuse the truncf linalg.generic into the matmul. - // This produces BF16 output directly from the F32 accumulation. - %matmul_to_fuse = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %truncf_generic = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %fused_generic = transform.air.fuse_truncf_linalg %truncf_generic, %matmul_to_fuse : (!transform.any_op, !transform.any_op) -> !transform.any_op - %fused_matmul = transform.structured.specialize %fused_generic : (!transform.any_op) -> !transform.any_op - - // Step 4: Promote the result buffer (C matrix) to L2 shared memory. - // memory_space = 1 corresponds to L2 (shared memory). - %result_l2 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %result_l2_buffer, %result_t2_new = transform.structured.bufferize_to_allocation %result_l2 - {memory_space = 1, bufferize_destination_only, mempcy = "linalg.copy", emit_dealloc} : !transform.any_op - - - //========================================================================== - // PHASE 3: PACK MATMUL FOR VECTORIZED COMPUTATION - // Apply data tiling (packing) to enable efficient vectorized computation. - //========================================================================== - - // Step 5: Pack matmul with tile sizes [8, 8, 8] for M, N, K dimensions. - // This transforms linalg.matmul into linalg.generic with packed layout - // optimized for AIE vector unit utilization. - %packed = transform.structured.pack %fused_matmul packed_sizes = [8, 8, 8] - : (!transform.any_op) -> (!transform.any_op) - - // Step 6: Transpose A matrix pack for correct memory layout. - // Outer permutation [1, 0] swaps the outer tile dimensions. - %pack_producer_a = transform.get_producer_of_operand %packed[0] - : (!transform.any_op) -> (!transform.any_op) - %packed_a, %pack_a, %empty_unpack_a = - transform.structured.pack_transpose %pack_producer_a with_compute_op(%packed) - outer_perm = [1, 0] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - // Step 7: Transpose B matrix pack for correct memory layout. - // Both outer_perm and inner_perm [1, 0] transpose outer and inner tile dimensions. - %pack_producer_b = transform.get_producer_of_operand %packed_a[1] - : (!transform.any_op) -> (!transform.any_op) - %packed_b, %pack_b, %empty_unpack_b = - transform.structured.pack_transpose %pack_producer_b with_compute_op(%packed_a) - outer_perm = [1, 0] inner_perm = [1, 0] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - // Step 8: Transpose C matrix pack/unpack for correct memory layout. - %unpack = transform.get_consumers_of_result %packed_b[0] - : (!transform.any_op) -> (!transform.any_op) - %packed_c, %pack_c, %unpack_c = - transform.structured.pack_transpose %unpack with_compute_op(%packed_b) - outer_perm = [1, 0] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - // Step 9: Promote the output pack operation to L1 local memory. - // memory_space = 2 corresponds to L1 (AIE local memory). - %output_l1_pack_op_source_buffer, %output_l1_pack_op_new = transform.structured.bufferize_to_allocation %pack_c - {memory_space = 2, bufferize_destination_only, memcpy_op = "linalg.copy", emit_dealloc} : !transform.any_op - - //========================================================================== - // PHASE 4: TILE REDUCTION AND FUSE PACK OPERATIONS - // Tile the K dimension and fuse data movement into compute loops. - //========================================================================== - - // Step 10: Tile the reduction (K) dimension with factor 8. - // This enables streaming of A and B tiles along the K dimension. - %tiled_reduction, %outer_for_loop = - transform.structured.tile_using_for %packed_c tile_sizes [0, 0, 8] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %outer_for_loop "k_reduction_loop" : !transform.any_op - - // Step 11: Fuse pack operations for A and B into the outer K-loop. - // This moves data packing inside the loop for better locality and pipelining. - %fused_lhs_l1_pack, %2 = transform.structured.fuse_into_containing_op %pack_a into %outer_for_loop : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - %fused_rhs_l1_pack, %3 = transform.structured.fuse_into_containing_op %pack_b into %outer_for_loop : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - - //========================================================================== - // PHASE 5: TILE FOR MULTI-CORE PARALLELISM - // Create parallel loops for mapping to 8x4 AIE core array. - //========================================================================== - - // Step 12: Tile matmul using scf.forall with tile sizes [8, 8, 0]. - // This introduces parallelism across M and N dimensions for multi-core execution. - %matmul_1 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %tiled_matmul_1, %inner_forall = - transform.structured.tile_using_forall %matmul_1 tile_sizes [8, 8, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %inner_forall "compute_forall" : !transform.any_op - transform.annotate %tiled_matmul_1 "matmul_compute" : !transform.any_op - - // Step 13: Fuse pack operations into the inner parallel loop. - // This ensures each core has its own data packing for independent execution. - %fused_lhs_l1_pack2, %6 = transform.structured.fuse_into_containing_op %fused_lhs_l1_pack into %inner_forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - %fused_rhs_l1_pack2, %7 = transform.structured.fuse_into_containing_op %fused_rhs_l1_pack into %inner_forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - - // Step 14: Canonicalization and CSE after tiling. - %func_2 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func_2 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func_2 : !transform.any_op - - //========================================================================== - // PHASE 6: PROMOTE INPUTS TO L1 AND TILE PROLOGUE/EPILOGUE - // Move input data to L1, create tiled fill (prologue) and unpack (epilogue). - //========================================================================== + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.consumed}) { - // Step 15: Promote input operands (A and B tiles) to L1 local memory. - %buffer_a, %new_a = transform.structured.bufferize_to_allocation %fused_lhs_l1_pack2 - {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op - %buffer_b, %new_b = transform.structured.bufferize_to_allocation %fused_rhs_l1_pack2 - {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op + %m1 = transform.apply_registered_pass "air-matmul-codegen" with options = { + "bufferize-output-l2" = true, "fuse-output-truncf-first" = true, + "tile-l3-to-l2-copies" = true, "k-l2-tile" = 64, + "l2-pack-sizes" = [8, 8, 8], + "l2-lhs-outer-perm" = [1, 0], "l2-lhs-inner-perm" = [0, 1], + "l2-rhs-outer-perm" = [1, 0], "l2-rhs-inner-perm" = [1, 0], + "l2-acc-outer-perm" = [1, 0], "l2-acc-inner-perm" = [0, 1], + "outer-k-tile-factor" = 8, "outer-k-iter-index" = 2, + "core-tile" = [8, 8, 0], + "prologue-tile" = [8, 8], "epilogue-tile" = [64, 64], + "fill-iter-perm" = [1, 0, 2, 3], + "one-shot-bufferize" = true, + "post-bufferize-cleanup-first" = true, + "matmul-vec-tile" = [2, 2, 1, 0, 0, 0], + "matmul-unroll-vec-tile" = [1, 1, 0, 0, 0, 0], + "matmul-unroll-factor" = 2, + "fill-vec-tile" = [1, 1, 0, 0] + } to %arg1 : (!transform.any_op) -> !transform.any_op - // Step 16: Create tiled prologue (fill operation). - // Generalize fill to generic, interchange dimensions, then tile with forall. - %fill_op = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %generic_fill_op = transform.structured.generalize %fill_op - : (!transform.any_op) -> !transform.any_op - transform.annotate %generic_fill_op "init_fill" : !transform.any_op - %interchanged_fill_op = transform.structured.interchange %generic_fill_op - iterator_interchange = [1, 0, 2, 3] - : (!transform.any_op) -> !transform.any_op - %prologue_tiled_fill, %prologue_forall = - transform.structured.tile_using_forall %interchanged_fill_op tile_sizes [8, 8] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %prologue_forall "prologue_forall" : !transform.any_op + %func1 = transform.structured.match ops{["func.func"]} in %m1 + : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "scf-forall-to-parallel" to %func1 + : (!transform.any_op) -> !transform.any_op + %m2 = transform.apply_registered_pass "air-par-to-herd" to %m1 + : (!transform.any_op) -> !transform.any_op + %func2 = transform.structured.match ops{["func.func"]} in %m2 + : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "air-herd-vectorize" to %func2 + : (!transform.any_op) -> !transform.any_op - // Step 17: Create tiled epilogue (unpack operation). - // Tile sizes [64, 64] match the L2 tile dimensions. - %unpack_op = transform.structured.match ops{["linalg.unpack"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %epilogue_tiled_unpack, %epilogue_forall = - transform.structured.tile_using_forall %unpack_op tile_sizes [64, 64] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %epilogue_forall "epilogue_forall" : !transform.any_op + %func3a = transform.structured.match ops{["func.func"]} in %m2 - // Step 18: Canonicalization and CSE after buffer promotion. - %func_3 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func_3 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func_3 : !transform.any_op + : (!transform.any_op) -> !transform.any_op - //========================================================================== - // PHASE 7: BUFFERIZATION AND MEMORY OPTIMIZATION - // Convert tensors to memrefs and optimize memory operations. - //========================================================================== + transform.apply_registered_pass "canonicalize" to %func3a - // Step 19: One-shot bufferization of the function. - // Converts all remaining tensors to memrefs for hardware execution. - %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_bufferized = transform.bufferization.one_shot_bufferize %func_op : (!transform.any_op) -> !transform.any_op + : (!transform.any_op) -> !transform.any_op - // Step 20: AIR-specific cleanup and memory optimization. - // Removes uninitialized copies and eliminates redundant cascade memcpy patterns. - %func6 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func6 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func6 : !transform.any_op - transform.apply_patterns to %func6 { - transform.apply_patterns.canonicalization - } : !transform.any_op - %func_op_updated = transform.air.remove_uninitialized_copy %func6 : (!transform.any_op) -> !transform.any_op - %func_op_updated_1 = transform.air.eliminate_cascade_memcpy %func_op_updated : (!transform.any_op) -> !transform.any_op + %func3b = transform.structured.match ops{["func.func"]} in %m2 - //========================================================================== - // PHASE 8: FUSE LOOPS FOR L2 PINGPONG BUFFERING - // Fuse L3->L2 copy loops with main compute loop for double buffering. - //========================================================================== + : (!transform.any_op) -> !transform.any_op - // Step 21: Fuse L3->L2 copy loops with the main K-reduction loop. - // This exposes L2 pingpong buffering opportunity by interleaving data transfer. - // Use annotation-based matching instead of fragile split_handle. - %for_loop_copy_1 = transform.structured.match ops{["scf.for"]} attributes{copy_a_loop} in %arg1 : (!transform.any_op) -> !transform.any_op - %for_loop_copy_2 = transform.structured.match ops{["scf.for"]} attributes{copy_b_loop} in %arg1 : (!transform.any_op) -> !transform.any_op - %main_for_loop = transform.structured.match ops{["scf.for"]} attributes{k_reduction_loop} in %arg1 : (!transform.any_op) -> !transform.any_op - %main_for_loop_norm = transform.air.normalize_for_bounds %main_for_loop : (!transform.any_op) -> !transform.any_op - transform.apply_cse to %func_op_updated_1 : !transform.any_op - %fused_for_loop_2 = transform.loop.fuse_sibling %for_loop_copy_2 into %main_for_loop_norm - : (!transform.any_op, !transform.any_op) -> !transform.any_op - %fused_for_loop_1 = transform.loop.fuse_sibling %for_loop_copy_1 into %fused_for_loop_2 - : (!transform.any_op, !transform.any_op) -> !transform.any_op + transform.apply_registered_pass "cse" to %func3b - //========================================================================== - // PHASE 9: TILE FOR VECTORIZATION - // Final tiling to enable efficient vectorized execution on AIE vector units. - //========================================================================== + : (!transform.any_op) -> !transform.any_op - // Step 22: Tile linalg.generic (matmul) for vectorization. - // Tile sizes [2, 2, 1, 0, 0, 0] create register blocking for M and N. - // Use annotation-based matching instead of fragile split_handle. - %generic1 = transform.structured.match ops{["linalg.generic"]} attributes{init_fill} in %arg1 : (!transform.any_op) -> !transform.any_op - %generic2 = transform.structured.match ops{["linalg.generic"]} attributes{matmul_compute} in %arg1 : (!transform.any_op) -> !transform.any_op - %inner_most_generics, %vec_loops:3 = - transform.structured.tile_using_for %generic2 tile_sizes [2, 2, 1, 0, 0, 0] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + %func3c = transform.structured.match ops{["func.func"]} in %m2 - // Step 23: Further tile and unroll innermost loops for full vectorization. - // Completely unrolls the innermost M and N loops for register allocation. - %inner_most_matmul_to_unroll, %vec_loops_to_unroll:2 = - transform.structured.tile_using_for %inner_most_generics tile_sizes [1, 1, 0, 0, 0, 0] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - transform.loop.unroll %vec_loops_to_unroll#1 {factor = 2} : !transform.any_op - transform.loop.unroll %vec_loops_to_unroll#0 {factor = 2} : !transform.any_op + : (!transform.any_op) -> !transform.any_op - // Step 24: Tile linalg.generic (fill) for vectorized initialization. - %inner_most_fills, %vec_fill_loops:2 = - transform.structured.tile_using_for %generic1 tile_sizes [1, 1] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.apply_registered_pass "fold-memref-alias-ops" to %func3c - //========================================================================== - // PHASE 10: CONVERT TO AIE HERDS AND VECTORIZE - // Map parallel loops to AIE cores (herds) and apply vectorization. - //========================================================================== + : (!transform.any_op) -> !transform.any_op - // Step 25: Convert scf.forall loops to AIE herd operations. - // Each forall becomes an air.herd representing multi-core execution. - // Use annotation-based matching instead of fragile split_handle. - %forall1 = transform.structured.match ops{["scf.forall"]} attributes{prologue_forall} in %arg1 : (!transform.any_op) -> !transform.any_op - %forall2 = transform.structured.match ops{["scf.forall"]} attributes{compute_forall} in %arg1 : (!transform.any_op) -> !transform.any_op - %forall3 = transform.structured.match ops{["scf.forall"]} attributes{epilogue_forall} in %arg1 : (!transform.any_op) -> !transform.any_op - %parallel1 = transform.loop.forall_to_parallel %forall1 : (!transform.any_op) -> !transform.any_op - %herd1 = transform.air.par_to_herd %parallel1 : (!transform.any_op) -> !transform.any_op - transform.annotate %herd1 "prologue_herd" : !transform.any_op - %parallel2 = transform.loop.forall_to_parallel %forall2 : (!transform.any_op) -> !transform.any_op - %herd2 = transform.air.par_to_herd %parallel2 : (!transform.any_op) -> !transform.any_op - transform.annotate %herd2 "compute_herd" : !transform.any_op - %parallel3 = transform.loop.forall_to_parallel %forall3 : (!transform.any_op) -> !transform.any_op - %herd3 = transform.air.par_to_herd %parallel3 : (!transform.any_op) -> !transform.any_op - transform.annotate %herd3 "epilogue_herd" : !transform.any_op + %m3 = transform.apply_registered_pass "air-matmul-codegen" with options = { + "vec-prep-cast1-target-element-type" = "f32", + "vec-prep-cast1-input-indices" = [2], + "vec-prep-cast1-output-indices" = [0], + "vec-prep-hoist-cast-pairs" = true + } to %m2 : (!transform.any_op) -> !transform.any_op - // Step 26: Apply vectorization to AIE herds. - // Converts scalar operations to vector operations for AIE vector units. - %herds = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %vectorized_herds = transform.air.herd_vectorize %herds : (!transform.any_op) -> !transform.any_op + %func4a = transform.structured.match ops{["func.func"]} in %m3 - // Step 27: Canonicalization after vectorization. - %func7 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func7 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - transform.apply_patterns.memref.fold_memref_alias_ops - } : !transform.any_op - %func_fold_1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_folded_1 = transform.air.fold_unit_extent_dims %func_fold_1 : (!transform.any_op) -> !transform.any_op + : (!transform.any_op) -> !transform.any_op - // Step 28: Eliminate redundant vector.transfer_read operations. - %func7_rematch = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func1_optimized = transform.air.eliminate_redundant_vector_transfers %func7_rematch : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "canonicalize" to %func4a - //========================================================================== - // PHASE 11: HOIST LOOP-INVARIANT VECTOR TRANSFERS - // Move vector reads/writes out of innermost loops for register reuse. - //========================================================================== + : (!transform.any_op) -> !transform.any_op - // Step 29: Identify the matmul compute herd and innermost K-loop. - %herd2_1 = transform.structured.match ops{["air.herd"]} attributes{compute_herd} in %arg1 : (!transform.any_op) -> !transform.any_op - %scf_fors_1 = transform.structured.match ops{["scf.for"]} in %herd2_1 : (!transform.any_op) -> !transform.any_op - %innermost_for, %outer_fors = transform.split_handle %scf_fors_1 {overflow_result = 1} : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %func4b = transform.structured.match ops{["func.func"]} in %m3 - // Step 30: Cast vector types for correct accumulation precision. - %vector_contracts = transform.structured.match ops{["vector.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %result11 = transform.air.vector_type_cast %vector_contracts {target_element_type = f32, input_indices = [2], output_indices = [0]} : (!transform.any_op) -> !transform.any_op + : (!transform.any_op) -> !transform.any_op - // Step 33: Hoist all accumulator transfer pairs from innermost K-loop. - %innermost_for_updated_3 = transform.air.hoist_loop_invariant_transfers %herd2_1, %innermost_for : (!transform.any_op, !transform.any_op) -> !transform.any_op + transform.apply_registered_pass "cse" to %func4b - //========================================================================== - // PHASE 12: HOIST EXTF/TRUNCF CAST PAIRS FOR BF16 OUTPUT - // Move BF16<->F32 conversions out of innermost loop for efficiency. - //========================================================================== + : (!transform.any_op) -> !transform.any_op - // Step 34: Match extf/truncf operations in the innermost loop. - %fors_to_hoist_ptrs = transform.structured.match ops{["scf.for"]} in %herd2_1 : (!transform.any_op) -> !transform.any_op - %innermost_for1, %outer_fors1 = transform.split_handle %fors_to_hoist_ptrs {overflow_result = 1}: (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %all_extf_loop = transform.structured.match ops{["arith.extf"]} in %innermost_for1 : (!transform.any_op) -> !transform.any_op - %all_truncf_loop = transform.structured.match ops{["arith.truncf"]} in %innermost_for1 : (!transform.any_op) -> !transform.any_op - %extf_bf16_1, %extf_bf16_2, %extf_bf16_3, %extf_bf16_4 = transform.split_handle %all_extf_loop : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - %truncf_1, %truncf_2, %truncf_3, %truncf_4 = transform.split_handle %all_truncf_loop : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - %for1_1_hoisted_1 = transform.air.hoist_cast_pair %extf_bf16_1, %truncf_1, %innermost_for1 : (!transform.any_op, !transform.any_op, !transform.any_op) -> !transform.any_op - %all_extf_loop_2 = transform.structured.match ops{["arith.extf"]} in %for1_1_hoisted_1 : (!transform.any_op) -> !transform.any_op - %all_truncf_loop_2 = transform.structured.match ops{["arith.truncf"]} in %for1_1_hoisted_1 : (!transform.any_op) -> !transform.any_op - %extf_bf16_2_new, %e2_5, %e2_6 = transform.split_handle %all_extf_loop_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - %truncf_2_1, %truncf_2_2, %truncf_2_3 = transform.split_handle %all_truncf_loop_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - %for1_1_hoisted_2 = transform.air.hoist_cast_pair %extf_bf16_2_new, %truncf_2_1, %for1_1_hoisted_1 : (!transform.any_op, !transform.any_op, !transform.any_op) -> !transform.any_op - %all_extf_loop_3 = transform.structured.match ops{["arith.extf"]} in %for1_1_hoisted_2 : (!transform.any_op) -> !transform.any_op - %all_truncf_loop_3 = transform.structured.match ops{["arith.truncf"]} in %for1_1_hoisted_2 : (!transform.any_op) -> !transform.any_op - %extf_bf16_3_new, %e3_7 = transform.split_handle %all_extf_loop_3 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %truncf_3_1, %truncf_3_2 = transform.split_handle %all_truncf_loop_3 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %for1_1_hoisted_3 = transform.air.hoist_cast_pair %extf_bf16_3_new, %truncf_3_1, %for1_1_hoisted_2 : (!transform.any_op, !transform.any_op, !transform.any_op) -> !transform.any_op - %all_extf_loop_4 = transform.structured.match ops{["arith.extf"]} in %for1_1_hoisted_3 : (!transform.any_op) -> !transform.any_op - %all_truncf_loop_4 = transform.structured.match ops{["arith.truncf"]} in %for1_1_hoisted_3 : (!transform.any_op) -> !transform.any_op - %for1_1_hoisted_final = transform.air.hoist_cast_pair %all_extf_loop_4, %all_truncf_loop_4, %for1_1_hoisted_3 : (!transform.any_op, !transform.any_op, !transform.any_op) -> !transform.any_op + %func4c = transform.structured.match ops{["func.func"]} in %m3 - //========================================================================== - // PHASE 13: FINAL LOOP OPTIMIZATIONS - // Flatten iteration arguments and hoist pointer computations. - //========================================================================== + : (!transform.any_op) -> !transform.any_op - // Step 36: Flatten loop iteration arguments. - // Simplifies the loop structure by flattening iter_args. - %innermost_for_updated_4 = transform.air.flatten_for_iter_args %for1_1_hoisted_final : (!transform.any_op) -> !transform.any_op - %innermost_for_updated_5 = transform.air.hoist_vector_transfer_pointers %innermost_for_updated_4 : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "fold-memref-alias-ops" to %func4c - // Step 37: Final canonicalization pass. - // Cleans up the final IR for AIR/AIE lowering. - %func9 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func9 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - transform.apply_patterns.memref.fold_memref_alias_ops - } : !transform.any_op - %func_fold_2 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_folded_2 = transform.air.fold_unit_extent_dims %func_fold_2 : (!transform.any_op) -> !transform.any_op + : (!transform.any_op) -> !transform.any_op transform.yield } diff --git a/test/xrt/53_matmul_padding_bf16/run.py b/test/xrt/53_matmul_padding_bf16/run.py index 8634ef1b6..06535f5b1 100644 --- a/test/xrt/53_matmul_padding_bf16/run.py +++ b/test/xrt/53_matmul_padding_bf16/run.py @@ -38,6 +38,11 @@ help="Transform script path", ) parser.add_argument("-v", "--verbose", action="store_true") +parser.add_argument( + "--print-module-only", + action="store_true", + help="Print module after air-copy-to-dma and exit (debug aid).", +) parser.add_argument( "--compile-mode", type=str, @@ -173,26 +178,24 @@ pm = air.passmanager.PassManager.parse(pipeline) pm.run(air_module.operation) + # Drive matmul codegen via the transform script (delegates to the C++ + # air-matmul-codegen orchestrator via transform.apply_registered_pass). + # Defaults assume --k-l2-tile=16; rewrite k-l2-tile / outer-k-tile-factor + # in the script when the user picks a different value. with open(args.transform_script, "r") as f: transform_ir_string = f.read() - # Parametrize L2 K-tile size in the transform script. - if K_L2_TILE != 64: + if K_L2_TILE != 16: import re transform_ir_string = re.sub( - r"(tile_using_for %copy1 tile_sizes \[0, )64(\])", - rf"\g<1>{K_L2_TILE}\2", - transform_ir_string, - ) - transform_ir_string = re.sub( - r"(tile_using_for %copy2 tile_sizes \[)64(\])", - rf"\g<1>{K_L2_TILE}\2", + r'("k-l2-tile" = )16(\b)', + rf"\g<1>{K_L2_TILE}\g<2>", transform_ir_string, ) - k_red_tile = K_L2_TILE // 8 + k_factor = max(1, K_L2_TILE // 8) transform_ir_string = re.sub( - r"(tile_using_for %packed_c tile_sizes \[0, 0, )8(\])", - rf"\g<1>{k_red_tile}\2", + r'("outer-k-tile-factor" = )2(\b)', + rf"\g<1>{k_factor}\g<2>", transform_ir_string, ) transform_ir = Module.parse(transform_ir_string) @@ -218,6 +221,10 @@ pm = air.passmanager.PassManager.parse(pipeline) pm.run(air_module.operation) + if args.print_module_only: + print(air_module) + exit(0) + ############################################### # Compile and run ############################################### diff --git a/test/xrt/53_matmul_padding_bf16/transform_aie2p.mlir b/test/xrt/53_matmul_padding_bf16/transform_aie2p.mlir index 827247ac7..c4d9d2480 100644 --- a/test/xrt/53_matmul_padding_bf16/transform_aie2p.mlir +++ b/test/xrt/53_matmul_padding_bf16/transform_aie2p.mlir @@ -1,302 +1,98 @@ -// Transform Script for 128x256 Matmul with BF16 Output (Triton Ver4, Vectorized) +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT // -// Adapted from test/xrt/48_triton_matmul_ver4_strix_4x4_bf16_output/transform_aie2p.mlir -// for 128x256 output tile (M_TILE=128, N_TILE=256, K=256). -// -// Target configuration: 4x2 AIE core array (Strix) -// Data types: BF16 inputs, F32 accumulation, BF16 output -// -// After packing [8,8,8] with C outer_perm [1,0]: -// packed shape = [N/8, M/8, K/8, 8, 8, 8] = [32, 16, 32, 8, 8, 8] -// Phase 5 forall [8, 8, 0] → herd 4x2 = 8 cores -// -// Memory Hierarchy: -// L3 (DDR) -> L2 (Shared Memory, memory_space=1) -> L1 (AIE Local, memory_space=2) +// AIE2P (Strix) bf16-out matmul codegen via the C++ air-matmul-codegen +// orchestrator with non-tile-aligned M, N (padding via memtile DMA). +// Defaults match --k-l2-tile=16; run.py rewrites the k-l2-tile and +// outer-k-tile-factor values when --k-l2-tile differs. module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - - //========================================================================== - // PHASE 1: TILE L3->L2 MEMORY COPIES - // Convert memref.copy to linalg.copy and tile for streaming data movement. - //========================================================================== - - %func10 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func10_updated = transform.air.convert_memref_copy_to_linalg_copy %func10 : (!transform.any_op) -> !transform.any_op - %copies = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %copy1, %copy2 = transform.split_handle %copies : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %tiled_copy1, %tile_copy_loop1 = - transform.structured.tile_using_for %copy1 tile_sizes [0, 64] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %tile_copy_loop1 "copy_a_loop" : !transform.any_op - %tiled_copy2, %tile_copy_loop2 = - transform.structured.tile_using_for %copy2 tile_sizes [64] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %tile_copy_loop2 "copy_b_loop" : !transform.any_op - - //========================================================================== - // PHASE 2: FUSE TRUNCF AND PREPARE MATMUL - // Fuse the output truncation into matmul and promote output buffer to L2. - //========================================================================== - - %fill = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - - %matmul_to_fuse = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %truncf_generic = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %fused_generic = transform.air.fuse_truncf_linalg %truncf_generic, %matmul_to_fuse : (!transform.any_op, !transform.any_op) -> !transform.any_op - %fused_matmul = transform.structured.specialize %fused_generic : (!transform.any_op) -> !transform.any_op - - %result_l2 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %result_l2_buffer, %result_t2_new = transform.structured.bufferize_to_allocation %result_l2 - {memory_space = 1, bufferize_destination_only, mempcy = "linalg.copy", emit_dealloc} : !transform.any_op - - - //========================================================================== - // PHASE 3: PACK MATMUL FOR VECTORIZED COMPUTATION - // Pack sizes [8, 8, 8] for M, N, K dimensions. - //========================================================================== - - %packed = transform.structured.pack %fused_matmul packed_sizes = [8, 8, 8] - : (!transform.any_op) -> (!transform.any_op) - - %pack_producer_a = transform.get_producer_of_operand %packed[0] - : (!transform.any_op) -> (!transform.any_op) - %packed_a, %pack_a, %empty_unpack_a = - transform.structured.pack_transpose %pack_producer_a with_compute_op(%packed) - outer_perm = [1, 0] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - %pack_producer_b = transform.get_producer_of_operand %packed_a[1] - : (!transform.any_op) -> (!transform.any_op) - %packed_b, %pack_b, %empty_unpack_b = - transform.structured.pack_transpose %pack_producer_b with_compute_op(%packed_a) - outer_perm = [1, 0] inner_perm = [1, 0] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - %unpack = transform.get_consumers_of_result %packed_b[0] - : (!transform.any_op) -> (!transform.any_op) - %packed_c, %pack_c, %unpack_c = - transform.structured.pack_transpose %unpack with_compute_op(%packed_b) - outer_perm = [1, 0] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - %output_l1_pack_op_source_buffer, %output_l1_pack_op_new = transform.structured.bufferize_to_allocation %pack_c - {memory_space = 2, bufferize_destination_only, memcpy_op = "linalg.copy", emit_dealloc} : !transform.any_op - - //========================================================================== - // PHASE 4: TILE REDUCTION AND FUSE PACK OPERATIONS - // Tile K dimension with factor 8 and fuse packs into K-loop. - //========================================================================== - - %tiled_reduction, %outer_for_loop = - transform.structured.tile_using_for %packed_c tile_sizes [0, 0, 8] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %outer_for_loop "k_reduction_loop" : !transform.any_op - - %fused_lhs_l1_pack, %2 = transform.structured.fuse_into_containing_op %pack_a into %outer_for_loop : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - %fused_rhs_l1_pack, %3 = transform.structured.fuse_into_containing_op %pack_b into %outer_for_loop : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - - //========================================================================== - // PHASE 5: TILE FOR MULTI-CORE PARALLELISM - // For 128x256 tile: packed dims [32, 16], forall [8, 8, 0] → 4x2 herd = 8 cores. - //========================================================================== - - %matmul_1 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %tiled_matmul_1, %inner_forall = - transform.structured.tile_using_forall %matmul_1 tile_sizes [8, 8, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %inner_forall "compute_forall" : !transform.any_op - transform.annotate %tiled_matmul_1 "matmul_compute" : !transform.any_op - - %fused_lhs_l1_pack2, %6 = transform.structured.fuse_into_containing_op %fused_lhs_l1_pack into %inner_forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - %fused_rhs_l1_pack2, %7 = transform.structured.fuse_into_containing_op %fused_rhs_l1_pack into %inner_forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - - %func_2 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func_2 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func_2 : !transform.any_op - - //========================================================================== - // PHASE 6: PROMOTE INPUTS TO L1 AND TILE PROLOGUE/EPILOGUE - //========================================================================== + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.consumed}) { - %buffer_a, %new_a = transform.structured.bufferize_to_allocation %fused_lhs_l1_pack2 - {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op - %buffer_b, %new_b = transform.structured.bufferize_to_allocation %fused_rhs_l1_pack2 - {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op + %m1 = transform.apply_registered_pass "air-matmul-codegen" with options = { + "bufferize-output-l2" = true, "fuse-output-truncf-first" = true, + "tile-l3-to-l2-copies" = true, "k-l2-tile" = 16, + "l2-pack-sizes" = [8, 8, 8], + "l2-lhs-outer-perm" = [1, 0], "l2-lhs-inner-perm" = [0, 1], + "l2-rhs-outer-perm" = [1, 0], "l2-rhs-inner-perm" = [1, 0], + "l2-acc-outer-perm" = [1, 0], "l2-acc-inner-perm" = [0, 1], + "outer-k-tile-factor" = 2, "outer-k-iter-index" = 2, + "core-tile" = [8, 8, 0], + "prologue-tile" = [8, 8], "epilogue-tile" = [64, 64], + "fill-iter-perm" = [1, 0, 2, 3], + "one-shot-bufferize" = true, + "post-bufferize-cleanup-first" = true, + "matmul-vec-tile" = [2, 2, 1, 0, 0, 0], + "matmul-unroll-vec-tile" = [1, 1, 0, 0, 0, 0], + "matmul-unroll-factor" = 2, + "fill-vec-tile" = [1, 1, 0, 0] + } to %arg1 : (!transform.any_op) -> !transform.any_op - // Prologue: fill → generalize → interchange → tile_using_forall - %fill_op = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %generic_fill_op = transform.structured.generalize %fill_op - : (!transform.any_op) -> !transform.any_op - transform.annotate %generic_fill_op "init_fill" : !transform.any_op - %interchanged_fill_op = transform.structured.interchange %generic_fill_op - iterator_interchange = [1, 0, 2, 3] - : (!transform.any_op) -> !transform.any_op - %prologue_tiled_fill, %prologue_forall = - transform.structured.tile_using_forall %interchanged_fill_op tile_sizes [8, 8] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %prologue_forall "prologue_forall" : !transform.any_op + %func1 = transform.structured.match ops{["func.func"]} in %m1 + : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "scf-forall-to-parallel" to %func1 + : (!transform.any_op) -> !transform.any_op + %m2 = transform.apply_registered_pass "air-par-to-herd" to %m1 + : (!transform.any_op) -> !transform.any_op + %func2 = transform.structured.match ops{["func.func"]} in %m2 + : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "air-herd-vectorize" to %func2 + : (!transform.any_op) -> !transform.any_op - // Epilogue: unpack → tile_using_forall [64, 64] - %unpack_op = transform.structured.match ops{["linalg.unpack"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %epilogue_tiled_unpack, %epilogue_forall = - transform.structured.tile_using_forall %unpack_op tile_sizes [64, 64] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %epilogue_forall "epilogue_forall" : !transform.any_op + %func3a = transform.structured.match ops{["func.func"]} in %m2 - %func_3 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func_3 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func_3 : !transform.any_op + : (!transform.any_op) -> !transform.any_op - //========================================================================== - // PHASE 7: BUFFERIZATION AND MEMORY OPTIMIZATION - //========================================================================== + transform.apply_registered_pass "canonicalize" to %func3a - %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_bufferized = transform.bufferization.one_shot_bufferize %func_op : (!transform.any_op) -> !transform.any_op + : (!transform.any_op) -> !transform.any_op - %func6 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func6 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func6 : !transform.any_op - transform.apply_patterns to %func6 { - transform.apply_patterns.canonicalization - } : !transform.any_op - %func_op_updated = transform.air.remove_uninitialized_copy %func6 : (!transform.any_op) -> !transform.any_op - %func_op_updated_1 = transform.air.eliminate_cascade_memcpy %func_op_updated : (!transform.any_op) -> !transform.any_op + %func3b = transform.structured.match ops{["func.func"]} in %m2 - //========================================================================== - // PHASE 8: FUSE LOOPS FOR L2 PINGPONG BUFFERING - //========================================================================== + : (!transform.any_op) -> !transform.any_op - %for_loop_copy_1 = transform.structured.match ops{["scf.for"]} attributes{copy_a_loop} in %arg1 : (!transform.any_op) -> !transform.any_op - %for_loop_copy_2 = transform.structured.match ops{["scf.for"]} attributes{copy_b_loop} in %arg1 : (!transform.any_op) -> !transform.any_op - %main_for_loop = transform.structured.match ops{["scf.for"]} attributes{k_reduction_loop} in %arg1 : (!transform.any_op) -> !transform.any_op - %main_for_loop_norm = transform.air.normalize_for_bounds %main_for_loop : (!transform.any_op) -> !transform.any_op - transform.apply_cse to %func_op_updated_1 : !transform.any_op - %fused_for_loop_2 = transform.loop.fuse_sibling %for_loop_copy_2 into %main_for_loop_norm - : (!transform.any_op, !transform.any_op) -> !transform.any_op - %fused_for_loop_1 = transform.loop.fuse_sibling %for_loop_copy_1 into %fused_for_loop_2 - : (!transform.any_op, !transform.any_op) -> !transform.any_op + transform.apply_registered_pass "cse" to %func3b - //========================================================================== - // PHASE 9: TILE FOR VECTORIZATION - //========================================================================== + : (!transform.any_op) -> !transform.any_op - %generic1 = transform.structured.match ops{["linalg.generic"]} attributes{init_fill} in %arg1 : (!transform.any_op) -> !transform.any_op - %generic2 = transform.structured.match ops{["linalg.generic"]} attributes{matmul_compute} in %arg1 : (!transform.any_op) -> !transform.any_op - %inner_most_generics, %vec_loops:3 = - transform.structured.tile_using_for %generic2 tile_sizes [2, 2, 1, 0, 0, 0] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + %func3c = transform.structured.match ops{["func.func"]} in %m2 - %inner_most_matmul_to_unroll, %vec_loops_to_unroll:2 = - transform.structured.tile_using_for %inner_most_generics tile_sizes [1, 1, 0, 0, 0, 0] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - transform.loop.unroll %vec_loops_to_unroll#1 {factor = 2} : !transform.any_op - transform.loop.unroll %vec_loops_to_unroll#0 {factor = 2} : !transform.any_op + : (!transform.any_op) -> !transform.any_op - %inner_most_fills, %vec_fill_loops:2 = - transform.structured.tile_using_for %generic1 tile_sizes [1, 1] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.apply_registered_pass "fold-memref-alias-ops" to %func3c - //========================================================================== - // PHASE 10: CONVERT TO AIE HERDS AND VECTORIZE - //========================================================================== + : (!transform.any_op) -> !transform.any_op - %forall1 = transform.structured.match ops{["scf.forall"]} attributes{prologue_forall} in %arg1 : (!transform.any_op) -> !transform.any_op - %forall2 = transform.structured.match ops{["scf.forall"]} attributes{compute_forall} in %arg1 : (!transform.any_op) -> !transform.any_op - %forall3 = transform.structured.match ops{["scf.forall"]} attributes{epilogue_forall} in %arg1 : (!transform.any_op) -> !transform.any_op - %parallel1 = transform.loop.forall_to_parallel %forall1 : (!transform.any_op) -> !transform.any_op - %herd1 = transform.air.par_to_herd %parallel1 : (!transform.any_op) -> !transform.any_op - transform.annotate %herd1 "prologue_herd" : !transform.any_op - %parallel2 = transform.loop.forall_to_parallel %forall2 : (!transform.any_op) -> !transform.any_op - %herd2 = transform.air.par_to_herd %parallel2 : (!transform.any_op) -> !transform.any_op - transform.annotate %herd2 "compute_herd" : !transform.any_op - %parallel3 = transform.loop.forall_to_parallel %forall3 : (!transform.any_op) -> !transform.any_op - %herd3 = transform.air.par_to_herd %parallel3 : (!transform.any_op) -> !transform.any_op - transform.annotate %herd3 "epilogue_herd" : !transform.any_op + %m3 = transform.apply_registered_pass "air-matmul-codegen" with options = { + "vec-prep-cast1-target-element-type" = "f32", + "vec-prep-cast1-input-indices" = [2], + "vec-prep-cast1-output-indices" = [0], + "vec-prep-hoist-cast-pairs" = true + } to %m2 : (!transform.any_op) -> !transform.any_op - %herds = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %vectorized_herds = transform.air.herd_vectorize %herds : (!transform.any_op) -> !transform.any_op + %func4a = transform.structured.match ops{["func.func"]} in %m3 - %func7 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func7 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - transform.apply_patterns.memref.fold_memref_alias_ops - } : !transform.any_op - %func_fold_1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_folded_1 = transform.air.fold_unit_extent_dims %func_fold_1 : (!transform.any_op) -> !transform.any_op + : (!transform.any_op) -> !transform.any_op - %func7_rematch = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func1_optimized = transform.air.eliminate_redundant_vector_transfers %func7_rematch : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "canonicalize" to %func4a - //========================================================================== - // PHASE 11: HOIST LOOP-INVARIANT VECTOR TRANSFERS - //========================================================================== + : (!transform.any_op) -> !transform.any_op - %herd2_1 = transform.structured.match ops{["air.herd"]} attributes{compute_herd} in %arg1 : (!transform.any_op) -> !transform.any_op - %scf_fors_1 = transform.structured.match ops{["scf.for"]} in %herd2_1 : (!transform.any_op) -> !transform.any_op - %innermost_for, %outer_fors = transform.split_handle %scf_fors_1 {overflow_result = 1} : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %func4b = transform.structured.match ops{["func.func"]} in %m3 - %vector_contracts = transform.structured.match ops{["vector.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %result11 = transform.air.vector_type_cast %vector_contracts {target_element_type = f32, input_indices = [2], output_indices = [0]} : (!transform.any_op) -> !transform.any_op + : (!transform.any_op) -> !transform.any_op - %innermost_for_updated_3 = transform.air.hoist_loop_invariant_transfers %herd2_1, %innermost_for : (!transform.any_op, !transform.any_op) -> !transform.any_op + transform.apply_registered_pass "cse" to %func4b - //========================================================================== - // PHASE 12: HOIST EXTF/TRUNCF CAST PAIRS FOR BF16 OUTPUT - //========================================================================== + : (!transform.any_op) -> !transform.any_op - %fors_to_hoist_ptrs = transform.structured.match ops{["scf.for"]} in %herd2_1 : (!transform.any_op) -> !transform.any_op - %innermost_for1, %outer_fors1 = transform.split_handle %fors_to_hoist_ptrs {overflow_result = 1}: (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %all_extf_loop = transform.structured.match ops{["arith.extf"]} in %innermost_for1 : (!transform.any_op) -> !transform.any_op - %all_truncf_loop = transform.structured.match ops{["arith.truncf"]} in %innermost_for1 : (!transform.any_op) -> !transform.any_op - %extf_bf16_1, %extf_bf16_2, %extf_bf16_3, %extf_bf16_4 = transform.split_handle %all_extf_loop : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - %truncf_1, %truncf_2, %truncf_3, %truncf_4 = transform.split_handle %all_truncf_loop : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - %for1_1_hoisted_1 = transform.air.hoist_cast_pair %extf_bf16_1, %truncf_1, %innermost_for1 : (!transform.any_op, !transform.any_op, !transform.any_op) -> !transform.any_op - %all_extf_loop_2 = transform.structured.match ops{["arith.extf"]} in %for1_1_hoisted_1 : (!transform.any_op) -> !transform.any_op - %all_truncf_loop_2 = transform.structured.match ops{["arith.truncf"]} in %for1_1_hoisted_1 : (!transform.any_op) -> !transform.any_op - %extf_bf16_2_new, %e2_5, %e2_6 = transform.split_handle %all_extf_loop_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - %truncf_2_1, %truncf_2_2, %truncf_2_3 = transform.split_handle %all_truncf_loop_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - %for1_1_hoisted_2 = transform.air.hoist_cast_pair %extf_bf16_2_new, %truncf_2_1, %for1_1_hoisted_1 : (!transform.any_op, !transform.any_op, !transform.any_op) -> !transform.any_op - %all_extf_loop_3 = transform.structured.match ops{["arith.extf"]} in %for1_1_hoisted_2 : (!transform.any_op) -> !transform.any_op - %all_truncf_loop_3 = transform.structured.match ops{["arith.truncf"]} in %for1_1_hoisted_2 : (!transform.any_op) -> !transform.any_op - %extf_bf16_3_new, %e3_7 = transform.split_handle %all_extf_loop_3 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %truncf_3_1, %truncf_3_2 = transform.split_handle %all_truncf_loop_3 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %for1_1_hoisted_3 = transform.air.hoist_cast_pair %extf_bf16_3_new, %truncf_3_1, %for1_1_hoisted_2 : (!transform.any_op, !transform.any_op, !transform.any_op) -> !transform.any_op - %all_extf_loop_4 = transform.structured.match ops{["arith.extf"]} in %for1_1_hoisted_3 : (!transform.any_op) -> !transform.any_op - %all_truncf_loop_4 = transform.structured.match ops{["arith.truncf"]} in %for1_1_hoisted_3 : (!transform.any_op) -> !transform.any_op - %for1_1_hoisted_final = transform.air.hoist_cast_pair %all_extf_loop_4, %all_truncf_loop_4, %for1_1_hoisted_3 : (!transform.any_op, !transform.any_op, !transform.any_op) -> !transform.any_op + %func4c = transform.structured.match ops{["func.func"]} in %m3 - //========================================================================== - // PHASE 13: FINAL LOOP OPTIMIZATIONS - //========================================================================== + : (!transform.any_op) -> !transform.any_op - %innermost_for_updated_4 = transform.air.flatten_for_iter_args %for1_1_hoisted_final : (!transform.any_op) -> !transform.any_op - %innermost_for_updated_5 = transform.air.hoist_vector_transfer_pointers %innermost_for_updated_4 : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "fold-memref-alias-ops" to %func4c - %func9 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func9 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - transform.apply_patterns.memref.fold_memref_alias_ops - } : !transform.any_op - %func_fold_2 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_folded_2 = transform.air.fold_unit_extent_dims %func_fold_2 : (!transform.any_op) -> !transform.any_op + : (!transform.any_op) -> !transform.any_op transform.yield } diff --git a/test/xrt/54_matmul_padding_f32_bf16_emulation/run.py b/test/xrt/54_matmul_padding_f32_bf16_emulation/run.py index 9ae0e65c8..e00e6574e 100644 --- a/test/xrt/54_matmul_padding_f32_bf16_emulation/run.py +++ b/test/xrt/54_matmul_padding_f32_bf16_emulation/run.py @@ -163,9 +163,37 @@ pm = air.passmanager.PassManager.parse(pipeline) pm.run(air_module.operation) - # Apply transform script + # Drive matmul codegen via the transform script (delegates to the C++ + # air-matmul-codegen orchestrator via transform.apply_registered_pass). + # Defaults assume k-l2-tile=16 / herd=4x4 / TILE_M=64 / TILE_N=32 -> + # LT_M=256, LT_N=128, epilogue=64x32. Rewrite k-l2-tile + + # outer-k-tile-factor + epilogue-tile when those derived values differ. with open(transform_path, "r") as f: transform_ir_string = f.read() + epM = max(4 * 8, LT_M // HERD_M) + epN = max(1, LT_N // HERD_N) + if K_L2_TILE != 16: + import re + + transform_ir_string = re.sub( + r'("k-l2-tile" = )16(\b)', + rf"\g<1>{K_L2_TILE}\g<2>", + transform_ir_string, + ) + k_factor = max(1, K_L2_TILE // 8) + transform_ir_string = re.sub( + r'("outer-k-tile-factor" = )2(\b)', + rf"\g<1>{k_factor}\g<2>", + transform_ir_string, + ) + if (epM, epN) != (64, 32): + import re + + transform_ir_string = re.sub( + r'("epilogue-tile" = )\[64, 32\]', + rf"\g<1>[{epM}, {epN}]", + transform_ir_string, + ) transform_ir = Module.parse(transform_ir_string, context=air_module.context) run_transform(transform_ir, air_module) diff --git a/test/xrt/54_matmul_padding_f32_bf16_emulation/transform_aie2p.mlir b/test/xrt/54_matmul_padding_f32_bf16_emulation/transform_aie2p.mlir index 7435413ef..093329240 100644 --- a/test/xrt/54_matmul_padding_f32_bf16_emulation/transform_aie2p.mlir +++ b/test/xrt/54_matmul_padding_f32_bf16_emulation/transform_aie2p.mlir @@ -1,281 +1,80 @@ -// Transform Script for F32 Matmul with BF16 Emulation +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT // -// Starting IR: Full-K matmul (no K-loop), all f32, generated from asm_src params. -// - func @matmul_padding_kernel(memref<*xf32>*3, i32*6) -// - linalg.matmul(64xK @ Kx32 → 64x32), f32 accumulation -// - A in K×M layout (strides [1, M_alloc]), B in K×N (strides [N_alloc, 1]) -// -// Follows test 53's transform pattern: tile copies, pack [8,8,8], tile K, -// tile forall for multi-core, vectorize, hoist. -// -// Target: 4×8 AIE core array (Strix/NPU2), BFP16 emulation -// Tile sizes: M=64, N=32, K_L2=16, pack [8,8,8] +// AIE2P (Strix) f32-in/out matmul codegen with BFP16 emulation, via the +// C++ air-matmul-codegen orchestrator. Defaults match the Makefile's +// k-l2-tile=16 / herd=4x4 / TILE_M=64 / TILE_N=32 (LT=256x128); run.py +// rewrites k-l2-tile + outer-k-tile-factor + epilogue-tile when those +// derived values differ. module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - - //========================================================================== - // PHASE 1: TILE L3→L2 MEMORY COPIES - //========================================================================== - - %func10 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func10_updated = transform.air.convert_memref_copy_to_linalg_copy %func10 : (!transform.any_op) -> !transform.any_op - %copies = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %copy1, %copy2 = transform.split_handle %copies : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - // Tile A copy: 64×K → 64×16 tiles (K_L2_TILE=16) - %tiled_copy1, %tile_copy_loop1 = - transform.structured.tile_using_for %copy1 tile_sizes [0, 16] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %tile_copy_loop1 "copy_a_loop" : !transform.any_op - // Tile B copy: K×32 → 16×32 tiles - %tiled_copy2, %tile_copy_loop2 = - transform.structured.tile_using_for %copy2 tile_sizes [16] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %tile_copy_loop2 "copy_b_loop" : !transform.any_op - - //========================================================================== - // PHASE 2: PROMOTE OUTPUT TO L2 - // No truncf fusion needed (output is f32). - //========================================================================== - - %result_l2 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %result_l2_buffer, %result_t2_new = transform.structured.bufferize_to_allocation %result_l2 - {memory_space = 1, bufferize_destination_only, mempcy = "linalg.copy", emit_dealloc} : !transform.any_op - - //========================================================================== - // PHASE 3: PACK MATMUL FOR VECTORIZED COMPUTATION - // Pack sizes [8, 8, 8] for M, N, K dimensions. - //========================================================================== - - %matmul_to_pack = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %packed = transform.structured.pack %matmul_to_pack packed_sizes = [8, 8, 8] - : (!transform.any_op) -> (!transform.any_op) - - %pack_producer_a = transform.get_producer_of_operand %packed[0] - : (!transform.any_op) -> (!transform.any_op) - %packed_a, %pack_a, %empty_unpack_a = - transform.structured.pack_transpose %pack_producer_a with_compute_op(%packed) - outer_perm = [1, 0] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - %pack_producer_b = transform.get_producer_of_operand %packed_a[1] - : (!transform.any_op) -> (!transform.any_op) - %packed_b, %pack_b, %empty_unpack_b = - transform.structured.pack_transpose %pack_producer_b with_compute_op(%packed_a) - outer_perm = [1, 0] inner_perm = [1, 0] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - %unpack = transform.get_consumers_of_result %packed_b[0] - : (!transform.any_op) -> (!transform.any_op) - %packed_c, %pack_c, %unpack_c = - transform.structured.pack_transpose %unpack with_compute_op(%packed_b) - outer_perm = [1, 0] : (!transform.any_op, !transform.any_op) - -> (!transform.any_op, !transform.any_op, !transform.any_op) - - %output_l1_pack_op_source_buffer, %output_l1_pack_op_new = transform.structured.bufferize_to_allocation %pack_c - {memory_space = 2, bufferize_destination_only, memcpy_op = "linalg.copy", emit_dealloc} : !transform.any_op - - // Annotate the packed matmul so we can find it after K-tiling - transform.annotate %packed_c "packed_matmul" : !transform.any_op - - //========================================================================== - // PHASE 4: TILE K REDUCTION AND FUSE PACK OPERATIONS - // K/8 packed K-dim. Tile by 2 (= 16 raw K elements = K_L2_TILE). - //========================================================================== - - %tiled_reduction, %outer_for_loop = - transform.structured.tile_using_for %packed_c tile_sizes [0, 0, 2] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %outer_for_loop "k_reduction_loop" : !transform.any_op - - %fused_lhs_l1_pack, %2 = transform.structured.fuse_into_containing_op %pack_a into %outer_for_loop : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - %fused_rhs_l1_pack, %3 = transform.structured.fuse_into_containing_op %pack_b into %outer_for_loop : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - - //========================================================================== - // PHASE 5: TILE FOR MULTI-CORE PARALLELISM - // Packed C dims after pack [8,8,8] + outer_perm [1,0]: - // [N/8, M/8, K/8] = [16, 32, K/8] → tile [8, 4, 0] → forall(2, 8) - // par_to_herd maps to herd(8, 2) → collapse to 4×4 - //========================================================================== - - %matmul_1 = transform.structured.match ops{["linalg.generic"]} attributes{packed_matmul} in %arg1 : (!transform.any_op) -> !transform.any_op - %tiled_matmul_1, %inner_forall = - transform.structured.tile_using_forall %matmul_1 tile_sizes [8, 4, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %inner_forall "compute_forall" : !transform.any_op - transform.annotate %tiled_matmul_1 "matmul_compute" : !transform.any_op - - %fused_lhs_l1_pack2, %6 = transform.structured.fuse_into_containing_op %fused_lhs_l1_pack into %inner_forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - %fused_rhs_l1_pack2, %7 = transform.structured.fuse_into_containing_op %fused_rhs_l1_pack into %inner_forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - - %func_2 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func_2 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func_2 : !transform.any_op - - //========================================================================== - // PHASE 6: PROMOTE INPUTS TO L1 AND TILE PROLOGUE/EPILOGUE - //========================================================================== - - %buffer_a, %new_a = transform.structured.bufferize_to_allocation %fused_lhs_l1_pack2 - {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op - %buffer_b, %new_b = transform.structured.bufferize_to_allocation %fused_rhs_l1_pack2 - {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op - - // Prologue: fill → generalize → interchange → tile_using_forall - // After packing, fill is on packed 4D tensor [N/8, M/8, 8, 8] = [16, 32, 8, 8]. - // Interchange [1,0,2,3] swaps N/M dims → [32, 16, 8, 8]. - // Tile [8, 4] → forall(4, 4) matching herd. - %fill_op = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %generic_fill_op = transform.structured.generalize %fill_op - : (!transform.any_op) -> !transform.any_op - transform.annotate %generic_fill_op "init_fill" : !transform.any_op - %interchanged_fill_op = transform.structured.interchange %generic_fill_op - iterator_interchange = [1, 0, 2, 3] - : (!transform.any_op) -> !transform.any_op - %prologue_tiled_fill, %prologue_forall = - transform.structured.tile_using_forall %interchanged_fill_op tile_sizes [8, 4] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %prologue_forall "prologue_forall" : !transform.any_op - - // Epilogue: unpack → tile_using_forall [64, 32] for 4×4 herd - %unpack_op = transform.structured.match ops{["linalg.unpack"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %epilogue_tiled_unpack, %epilogue_forall = - transform.structured.tile_using_forall %unpack_op tile_sizes [64, 32] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.annotate %epilogue_forall "epilogue_forall" : !transform.any_op - - %func_3 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func_3 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func_3 : !transform.any_op - - //========================================================================== - // PHASE 7: BUFFERIZATION AND MEMORY OPTIMIZATION - //========================================================================== - - %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_bufferized = transform.bufferization.one_shot_bufferize %func_op : (!transform.any_op) -> !transform.any_op - - %func6 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func6 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func6 : !transform.any_op - transform.apply_patterns to %func6 { - transform.apply_patterns.canonicalization - } : !transform.any_op - %func_op_updated = transform.air.remove_uninitialized_copy %func6 : (!transform.any_op) -> !transform.any_op - %func_op_updated_1 = transform.air.eliminate_cascade_memcpy %func_op_updated : (!transform.any_op) -> !transform.any_op - - //========================================================================== - // PHASE 8: FUSE LOOPS FOR L2 PINGPONG BUFFERING - //========================================================================== - - %for_loop_copy_1 = transform.structured.match ops{["scf.for"]} attributes{copy_a_loop} in %arg1 : (!transform.any_op) -> !transform.any_op - %for_loop_copy_2 = transform.structured.match ops{["scf.for"]} attributes{copy_b_loop} in %arg1 : (!transform.any_op) -> !transform.any_op - %main_for_loop = transform.structured.match ops{["scf.for"]} attributes{k_reduction_loop} in %arg1 : (!transform.any_op) -> !transform.any_op - %main_for_loop_norm = transform.air.normalize_for_bounds %main_for_loop : (!transform.any_op) -> !transform.any_op - transform.apply_cse to %func_op_updated_1 : !transform.any_op - %fused_for_loop_2 = transform.loop.fuse_sibling %for_loop_copy_2 into %main_for_loop_norm - : (!transform.any_op, !transform.any_op) -> !transform.any_op - %fused_for_loop_1 = transform.loop.fuse_sibling %for_loop_copy_1 into %fused_for_loop_2 - : (!transform.any_op, !transform.any_op) -> !transform.any_op - - //========================================================================== - // PHASE 9: TILE FOR VECTORIZATION - //========================================================================== - - %generic1 = transform.structured.match ops{["linalg.generic"]} attributes{init_fill} in %arg1 : (!transform.any_op) -> !transform.any_op - %generic2 = transform.structured.match ops{["linalg.generic"]} attributes{matmul_compute} in %arg1 : (!transform.any_op) -> !transform.any_op - // Per-core packed matmul: [4, 8, K/8, 8, 8, 8]. - // Tile for vectorization: [2, 2, 1, 0, 0, 0] then unroll. - %inner_most_generics, %vec_loops:3 = - transform.structured.tile_using_for %generic2 tile_sizes [2, 2, 1, 0, 0, 0] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - - %inner_most_matmul_to_unroll, %vec_loops_to_unroll:2 = - transform.structured.tile_using_for %inner_most_generics tile_sizes [1, 1, 0, 0, 0, 0] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - transform.loop.unroll %vec_loops_to_unroll#1 {factor = 2} : !transform.any_op - transform.loop.unroll %vec_loops_to_unroll#0 {factor = 2} : !transform.any_op - - %inner_most_fills, %vec_fill_loops:2 = - transform.structured.tile_using_for %generic1 tile_sizes [1, 1, 0, 0] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - - //========================================================================== - // PHASE 10: CONVERT TO AIE HERDS AND VECTORIZE - //========================================================================== - - %forall1 = transform.structured.match ops{["scf.forall"]} attributes{prologue_forall} in %arg1 : (!transform.any_op) -> !transform.any_op - %forall2 = transform.structured.match ops{["scf.forall"]} attributes{compute_forall} in %arg1 : (!transform.any_op) -> !transform.any_op - %forall3 = transform.structured.match ops{["scf.forall"]} attributes{epilogue_forall} in %arg1 : (!transform.any_op) -> !transform.any_op - %parallel1 = transform.loop.forall_to_parallel %forall1 : (!transform.any_op) -> !transform.any_op - %herd1 = transform.air.par_to_herd %parallel1 : (!transform.any_op) -> !transform.any_op - transform.annotate %herd1 "prologue_herd" : !transform.any_op - %parallel2 = transform.loop.forall_to_parallel %forall2 : (!transform.any_op) -> !transform.any_op - %herd2 = transform.air.par_to_herd %parallel2 : (!transform.any_op) -> !transform.any_op - transform.annotate %herd2 "compute_herd" : !transform.any_op - %parallel3 = transform.loop.forall_to_parallel %forall3 : (!transform.any_op) -> !transform.any_op - %herd3 = transform.air.par_to_herd %parallel3 : (!transform.any_op) -> !transform.any_op - transform.annotate %herd3 "epilogue_herd" : !transform.any_op - - %herds = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %vectorized_herds = transform.air.herd_vectorize %herds : (!transform.any_op) -> !transform.any_op - - %func7 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func7 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - transform.apply_patterns.memref.fold_memref_alias_ops - } : !transform.any_op - %func_fold_1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_folded_1 = transform.air.fold_unit_extent_dims %func_fold_1 : (!transform.any_op) -> !transform.any_op - - %func7_rematch = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func1_optimized = transform.air.eliminate_redundant_vector_transfers %func7_rematch : (!transform.any_op) -> !transform.any_op - - //========================================================================== - // PHASE 11: HOIST LOOP-INVARIANT VECTOR TRANSFERS - //========================================================================== - - %herd2_1 = transform.structured.match ops{["air.herd"]} attributes{compute_herd} in %arg1 : (!transform.any_op) -> !transform.any_op - %scf_fors_1 = transform.structured.match ops{["scf.for"]} in %herd2_1 : (!transform.any_op) -> !transform.any_op - %innermost_for, %outer_fors = transform.split_handle %scf_fors_1 {overflow_result = 1} : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - - // Cast vector.contract input types: inputs 0,1 to bf16, accumulator 2 and output to f32 - %vector_contracts = transform.structured.match ops{["vector.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %result11 = transform.air.vector_type_cast %vector_contracts {target_element_type = f32, input_indices = [2], output_indices = [0]} : (!transform.any_op) -> !transform.any_op - %vector_contracts_2 = transform.structured.match ops{["vector.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %result11b = transform.air.vector_type_cast %vector_contracts_2 {target_element_type = bf16, input_indices = [0, 1], output_indices = []} : (!transform.any_op) -> !transform.any_op - - %innermost_for_updated_3 = transform.air.hoist_loop_invariant_transfers %herd2_1, %innermost_for : (!transform.any_op, !transform.any_op) -> !transform.any_op - - //========================================================================== - // PHASE 12: FINAL LOOP OPTIMIZATIONS - //========================================================================== - - %innermost_for_updated_4 = transform.air.flatten_for_iter_args %innermost_for_updated_3 : (!transform.any_op) -> !transform.any_op - %innermost_for_updated_5 = transform.air.hoist_vector_transfer_pointers %innermost_for_updated_4 : (!transform.any_op) -> !transform.any_op - - %func9 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func9 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - transform.apply_patterns.memref.fold_memref_alias_ops - } : !transform.any_op - %func_fold_2 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %func_folded_2 = transform.air.fold_unit_extent_dims %func_fold_2 : (!transform.any_op) -> !transform.any_op + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.consumed}) { + + %m1 = transform.apply_registered_pass "air-matmul-codegen" with options = { + "bufferize-output-l2" = true, + "tile-l3-to-l2-copies" = true, "k-l2-tile" = 16, + "l2-pack-sizes" = [8, 8, 8], + "l2-lhs-outer-perm" = [1, 0], "l2-lhs-inner-perm" = [0, 1], + "l2-rhs-outer-perm" = [1, 0], "l2-rhs-inner-perm" = [1, 0], + "l2-acc-outer-perm" = [1, 0], "l2-acc-inner-perm" = [0, 1], + "outer-k-tile-factor" = 2, "outer-k-iter-index" = 2, + "core-tile" = [8, 4, 0], + "prologue-tile" = [8, 4], "epilogue-tile" = [64, 32], + "fill-iter-perm" = [1, 0, 2, 3], + "one-shot-bufferize" = true, + "post-bufferize-cleanup-first" = true, + "matmul-vec-tile" = [2, 2, 1, 0, 0, 0], + "matmul-unroll-vec-tile" = [1, 1, 0, 0, 0, 0], + "matmul-unroll-factor" = 2, + "fill-vec-tile" = [1, 1, 0, 0] + } to %arg1 : (!transform.any_op) -> !transform.any_op + + %func1 = transform.structured.match ops{["func.func"]} in %m1 + : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "scf-forall-to-parallel" to %func1 + : (!transform.any_op) -> !transform.any_op + %m2 = transform.apply_registered_pass "air-par-to-herd" to %m1 + : (!transform.any_op) -> !transform.any_op + %func2 = transform.structured.match ops{["func.func"]} in %m2 + : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "air-herd-vectorize" to %func2 + : (!transform.any_op) -> !transform.any_op + + // Cleanup: canonicalize + cse + fold-memref-alias-ops as full passes + // (not just apply_patterns, which is one-shot and doesn't iterate). + %func3a = transform.structured.match ops{["func.func"]} in %m2 + : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "canonicalize" to %func3a + : (!transform.any_op) -> !transform.any_op + %func3b = transform.structured.match ops{["func.func"]} in %m2 + : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "cse" to %func3b + : (!transform.any_op) -> !transform.any_op + %func3c = transform.structured.match ops{["func.func"]} in %m2 + : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "fold-memref-alias-ops" to %func3c + : (!transform.any_op) -> !transform.any_op + + %m3 = transform.apply_registered_pass "air-matmul-codegen" with options = { + "vec-prep-cast1-target-element-type" = "f32", + "vec-prep-cast1-input-indices" = [2], + "vec-prep-cast1-output-indices" = [0], + "vec-prep-cast2-target-element-type" = "bf16", + "vec-prep-cast2-input-indices" = [0, 1] + } to %m2 : (!transform.any_op) -> !transform.any_op + + %func4a = transform.structured.match ops{["func.func"]} in %m3 + : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "canonicalize" to %func4a + : (!transform.any_op) -> !transform.any_op + %func4b = transform.structured.match ops{["func.func"]} in %m3 + : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "cse" to %func4b + : (!transform.any_op) -> !transform.any_op + %func4c = transform.structured.match ops{["func.func"]} in %m3 + : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "fold-memref-alias-ops" to %func4c + : (!transform.any_op) -> !transform.any_op transform.yield }