diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Map/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/Map/BUILD.bazel new file mode 100644 index 000000000000..29675991bf40 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Map/BUILD.bazel @@ -0,0 +1,11 @@ +# Copyright 2026 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Map/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/Map/CMakeLists.txt new file mode 100644 index 000000000000..72bd7d684758 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Map/CMakeLists.txt @@ -0,0 +1,13 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Codegen/Dialect/Map/BUILD.bazel # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/BUILD.bazel new file mode 100644 index 000000000000..edee20ad1f07 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/BUILD.bazel @@ -0,0 +1,95 @@ +# Copyright 2026 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_gentbl_cc_library", "iree_td_library") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files([ + "IREEMapAttrs.td", + "IREEMapBase.td", +]) + +iree_td_library( + name = "td_files", + srcs = enforce_glob( + # keep sorted + [ + "IREEMapAttrs.td", + "IREEMapBase.td", + ], + include = ["*.td"], + ), + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +iree_compiler_cc_library( + name = "IREEMapDialect", + srcs = [ + "IREEMapAttrs.cpp", + "IREEMapDialect.cpp", + "IntTuple.cpp", + ], + hdrs = [ + "IREEMapAttrs.h", + "IREEMapDialect.h", + "IntTuple.h", + ], + textual_hdrs = [ + "IREEMapAttrs.cpp.inc", + "IREEMapAttrs.h.inc", + "IREEMapDialect.cpp.inc", + "IREEMapDialect.h.inc", + ], + deps = [ + ":IREEMapAttrsGen", + ":IREEMapDialectGen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +iree_gentbl_cc_library( + name = "IREEMapDialectGen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "IREEMapDialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "IREEMapDialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "IREEMapBase.td", + deps = [":td_files"], +) + +iree_gentbl_cc_library( + name = "IREEMapAttrsGen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "IREEMapAttrs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "IREEMapAttrs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "IREEMapAttrs.td", + deps = [":td_files"], +) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/CMakeLists.txt new file mode 100644 index 000000000000..b19e567ba222 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/CMakeLists.txt @@ -0,0 +1,58 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Codegen/Dialect/Map/IR/BUILD.bazel # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_cc_library( + NAME + IREEMapDialect + HDRS + "IREEMapAttrs.h" + "IREEMapDialect.h" + "IntTuple.h" + TEXTUAL_HDRS + "IREEMapAttrs.cpp.inc" + "IREEMapAttrs.h.inc" + "IREEMapDialect.cpp.inc" + "IREEMapDialect.h.inc" + SRCS + "IREEMapAttrs.cpp" + "IREEMapDialect.cpp" + "IntTuple.cpp" + DEPS + ::IREEMapAttrsGen + ::IREEMapDialectGen + LLVMSupport + MLIRIR + MLIRSupport + PUBLIC +) + +iree_tablegen_library( + NAME + IREEMapDialectGen + TD_FILE + "IREEMapBase.td" + OUTS + --gen-dialect-decls IREEMapDialect.h.inc + --gen-dialect-defs IREEMapDialect.cpp.inc +) + +iree_tablegen_library( + NAME + IREEMapAttrsGen + TD_FILE + "IREEMapAttrs.td" + OUTS + --gen-attrdef-decls IREEMapAttrs.h.inc + --gen-attrdef-defs IREEMapAttrs.cpp.inc +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/IREEMapAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/IREEMapAttrs.cpp new file mode 100644 index 000000000000..1e22aba31f99 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/IREEMapAttrs.cpp @@ -0,0 +1,682 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// This PackMap is inspired by CuTe layouts +// (https://arxiv.org/pdf/2603.02298v1), but adapted for IREE's use case. +// +// The layout algebra implementation is derived from code: +// Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights +// reserved. with BSD-3-Clause license. +// https://github.com/pytorch/pytorch/blob/main/torch/distributed/_pycute/layout.py + +#include "iree/compiler/Codegen/Dialect/Map/IR/IREEMapAttrs.h" + +#include "iree/compiler/Codegen/Dialect/Map/IR/IntTuple.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/MathExtras.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/DialectImplementation.h" + +using namespace mlir; +using namespace mlir::iree_compiler::IREE::Map; + +//===----------------------------------------------------------------------===// +// IntTuple parsing/printing helpers +//===----------------------------------------------------------------------===// + +/// Parse an IntTuple: +/// IntTuple ::= Integer | '(' IntTuple (',' IntTuple)* ')' +/// Uses an iterative shift-reduce approach with an explicit stack to avoid +/// recursion. +static FailureOr parseIntTupleImpl(AsmParser &parser) { + MLIRContext *ctx = parser.getContext(); + + SmallVector> stack; + Attribute result; + bool startItem = true; + + while (true) { + if (startItem) { + if (succeeded(parser.parseOptionalLParen())) { + stack.push_back({}); + continue; + } + int64_t val; + if (failed(parser.parseInteger(val))) { + return failure(); + } + result = makeLeaf(ctx, val); + startItem = false; + continue; + } + if (stack.empty()) { + return result; + } + stack.back().push_back(result); + if (succeeded(parser.parseOptionalComma())) { + startItem = true; + continue; + } + if (failed(parser.parseRParen())) { + return failure(); + } + result = makeTuple(ctx, stack.back()); + stack.pop_back(); + } +} + +static ParseResult parseIntTuple(AsmParser &parser, Attribute &result) { + FailureOr parsed = parseIntTupleImpl(parser); + if (failed(parsed)) { + return failure(); + } + result = *parsed; + return success(); +} + +/// Print an IntTuple. +static void printIntTuple(AsmPrinter &printer, Attribute attr) { + if (isLeaf(attr)) { + printer << getLeafValue(attr); + return; + } + printer << "("; + llvm::interleaveComma(cast(attr), printer, + [&](Attribute elem) { printIntTuple(printer, elem); }); + printer << ")"; +} + +LogicalResult PackMapAttr::verify(function_ref emitError, + Attribute shape, Attribute stride) { + if (!isIntTuple(shape)) { + return emitError() << "shape must be a valid IntTuple"; + } + if (!isIntTuple(stride)) { + return emitError() << "stride must be a valid IntTuple"; + } + if (!isCongruent(shape, stride)) { + return emitError() + << "shape and stride must be congruent (identical tree structure)"; + } + + for (int64_t v : getLeaves(shape)) { + if (v <= 0) { + return emitError() << "shape leaf values must be positive, got " << v; + } + } + for (int64_t v : getLeaves(stride)) { + if (v < 0) { + return emitError() << "stride leaf values must be non-negative, got " + << v; + } + } + // Overflow checking for getSize() / getCosize() is intentionally omitted: + // those values can overflow int64_t for very large layouts, which is + // documented as caller responsibility. + + return success(); +} + +//===----------------------------------------------------------------------===// +// PackMapAttr - property methods +//===----------------------------------------------------------------------===// + +int64_t PackMapAttr::getRank() { return Map::getRank(getShape()); } + +int64_t PackMapAttr::getDepth() { return Map::getDepth(getShape()); } + +int64_t PackMapAttr::getSize() { return Map::getSize(getShape()); } + +int64_t PackMapAttr::getCosize() { + int64_t maximumIndex = + foldLeafInfos(getShape(), getStride(), 0, + [](int64_t acc, const LeafInfo &leaf) -> int64_t { + return acc + (leaf.size - 1) * leaf.stride; + }); + return maximumIndex + 1; +} + +Attribute PackMapAttr::getShapeMode(int64_t i) { + return Map::getElement(getShape(), i); +} + +Attribute PackMapAttr::getStrideMode(int64_t i) { + return Map::getElement(getStride(), i); +} + +//===----------------------------------------------------------------------===// +// PackMapAttr - evaluation +//===----------------------------------------------------------------------===// + +int64_t PackMapAttr::evaluate(ArrayRef coord) { + SmallVector shapes = getLeaves(getShape()); + + SmallVector natCoord; + if (coord.size() == 1 && shapes.size() > 1) { + natCoord = idx2crd(coord[0], getShape()); + } else { + assert(coord.size() == shapes.size() && + "coordinate size must match number of leaf shapes"); + natCoord.assign(coord.begin(), coord.end()); + } + + return crd2idx(natCoord, getStride()); +} + +//===----------------------------------------------------------------------===// +// PackMapAttr - simplification +//===----------------------------------------------------------------------===// + +/// Flatten all hierarchy into a single-level tuple of leaves. +/// ((2, 4), 8) : ((16, 1), 4) -> (2, 4, 8) : (16, 1, 4). +PackMapAttr PackMapAttr::flatten() { + MLIRContext *ctx = getContext(); + return PackMapAttr::get(ctx, Map::flatten(ctx, getShape()), + Map::flatten(ctx, getStride())); +} + +/// Core merge: scan flat (shapes, strides) right-to-left and merge adjacent +/// contiguous leaves. Returns merged pairs in left-to-right order. +/// +/// Two adjacent leaves can be merged into one when they cover a contiguous +/// range of indices. Scanning right-to-left (innermost first), we accumulate +/// into (accShape, accStride) and check each new leaf (si, di): +/// +/// - si == 1: trivial, contributes nothing -- skip. +/// - accShape == 1: accumulator is trivial -- replace it with (si, di). +/// - di == accShape * accStride: the new leaf starts exactly where the +/// accumulator ends, so they are contiguous -- merge into +/// (si * accShape, accStride). +/// - otherwise: non-contiguous -- push the accumulator and start fresh. +/// +/// Example: flatten (4, (2, 4)) : (8, (4, 1)) -> leaves (4, 2, 4) : (8, 4, 1). +/// Start: acc = (4, 1). +/// (2, 4): accShape*accStride = 4*1 = 4 == 4 -> merge: acc = (8, 1). +/// (4, 8): accShape*accStride = 8*1 = 8 == 8 -> merge: acc = (32, 1). +/// Result: [(32, 1)] -> (32) : (1). +static SmallVector> +coalesceImpl(ArrayRef shapes, ArrayRef strides) { + assert(shapes.size() == strides.size()); + SmallVector> result; + if (shapes.empty()) { + return result; + } + result.push_back({shapes.back(), strides.back()}); + for (int i = static_cast(shapes.size()) - 2; i >= 0; --i) { + int64_t si = shapes[i], di = strides[i]; + auto &[accShape, accStride] = result.back(); + if (si == 1) { + continue; + } + if (accShape == 1) { + accShape = si; + accStride = di; + } else if (accShape * accStride == di) { + accShape = si * accShape; + } else { + result.push_back({si, di}); + } + } + std::reverse(result.begin(), result.end()); + return result; +} + +/// Merge adjacent leaves across all modes, including across mode boundaries. +/// +/// Example: (4, (2, 4)) : (8, (4, 1)) -> (32) : (1). +/// All three leaves (4:8, 2:4, 4:1) are contiguous and merge into one. +/// Compare: coalesceModes() on the same input gives (4, 8) : (8, 1), +/// preserving the mode boundary. +/// +/// Normalizes size-1 leaves to stride 0 to produce a canonical form: any +/// size-1 mode contributes nothing to the index, so its stride is irrelevant. +PackMapAttr PackMapAttr::coalesce() { + MLIRContext *ctx = getContext(); + SmallVector newShape, newStride; + for (auto [s, d] : + coalesceImpl(getLeaves(getShape()), getLeaves(getStride()))) { + newShape.push_back(makeLeaf(ctx, s)); + newStride.push_back(makeLeaf(ctx, s == 1 ? 0 : d)); + } + return PackMapAttr::get(ctx, makeTuple(ctx, newShape), + makeTuple(ctx, newStride)); +} + +/// Coalesce within each top-level mode independently, preserving mode +/// boundaries. Unlike coalesce(), leaves are never merged across modes. +/// +/// Example: (4, (2, 4)) : (8, (4, 1)) -> (4, 8) : (8, 1). +/// Mode 0 is a leaf (4:8) -- unchanged. +/// Mode 1 sub-leaves (2, 4):(4, 1) are contiguous -- merge to (8):(1). +/// The mode boundary between 4:8 and (2,4):(4,1) is preserved. +/// Compare: coalesce() on the same input gives (32) : (1). +PackMapAttr PackMapAttr::coalesceModes() { + MLIRContext *ctx = getContext(); + SmallVector newShape, newStride; + int64_t rank = getRank(); + for (int64_t i = 0; i < rank; ++i) { + Attribute mShape = getShapeMode(i), mStride = getStrideMode(i); + if (isLeaf(mShape)) { + newShape.push_back(mShape); + newStride.push_back(getLeafValue(mShape) == 1 ? makeLeaf(ctx, 0) + : mStride); + continue; + } + SmallVector> merged = + coalesceImpl(getLeaves(mShape), getLeaves(mStride)); + SmallVector ms, md; + for (auto [s, d] : merged) { + ms.push_back(makeLeaf(ctx, s)); + md.push_back(makeLeaf(ctx, s == 1 ? 0 : d)); + } + newShape.push_back(ms.size() == 1 ? ms[0] : makeTuple(ctx, ms)); + newStride.push_back(md.size() == 1 ? md[0] : makeTuple(ctx, md)); + } + return PackMapAttr::get(ctx, makeTuple(ctx, newShape), + makeTuple(ctx, newStride)); +} + +//===----------------------------------------------------------------------===// +// PackMapAttr - algebra +//===----------------------------------------------------------------------===// + +/// Functional composition: result(c) = this(rhs(c)). +/// +/// A coalesced LHS has leaves that form a mixed-radix decomposition of its +/// index space: the innermost (rightmost) leaf covers [0, s0), the next +/// covers multiples of s0, etc. Each RHS leaf with stride `d` selects every +/// d-th index from LHS. We need to figure out which LHS leaves each RHS leaf +/// "lands on" -- i.e., how to distribute the RHS leaf across the LHS +/// mixed-radix digits. +/// +/// We walk LHS leaves right-to-left (innermost first). For each, we compute +/// newShape = min(max(1, lhsShape / restStride), restShape) +/// This determines how many positions the RHS leaf covers in this LHS digit: +/// - If restStride < lhsShape, the RHS steps within this digit (partial +/// coverage), so newShape = lhsShape / restStride (capped by restShape). +/// - If restStride >= lhsShape, the RHS skips this digit entirely +/// (newShape = 1, pruned from output). +/// The result stride for each piece is restStride * lhsStride, composing +/// the RHS step size with the LHS digit's stride. After each digit, +/// restShape shrinks by the positions consumed, and restStride is divided +/// by the digit's shape to shift into the next digit's scale. +/// +/// Special cases: stride-0 RHS leaves broadcast (always map to index 0). +/// Single-leaf LHS is a simple stride scaling. +PackMapAttr PackMapAttr::compose(PackMapAttr rhs) { + MLIRContext *ctx = getContext(); + PackMapAttr lhs = this->coalesce(); + SmallVector lhsShapes = getLeaves(lhs.getShape()); + SmallVector lhsStrides = getLeaves(lhs.getStride()); + + SmallVector rhsShapes = getLeaves(rhs.getShape()); + SmallVector rhsStrides = getLeaves(rhs.getStride()); + + SmallVector resultShapes, resultStrides; + int n = static_cast(lhsShapes.size()); + + for (auto [rhsS, rhsD] : llvm::zip_equal(rhsShapes, rhsStrides)) { + if (rhsD == 0) { + resultShapes.push_back(makeLeaf(ctx, rhsS)); + resultStrides.push_back(makeLeaf(ctx, 0)); + continue; + } + + if (n == 1) { + resultShapes.push_back(makeLeaf(ctx, rhsS)); + resultStrides.push_back(makeLeaf(ctx, rhsD * lhsStrides[0])); + continue; + } + + SmallVector modeShapes, modeStrides; + int64_t restShape = rhsS; + int64_t restStride = rhsD; + + for (int j = n - 1; j >= 1; --j) { + int64_t currShape = lhsShapes[j]; + int64_t currStride = lhsStrides[j]; + + assert((currShape % restStride == 0 || restStride % currShape == 0) && + "compose: RHS stride must divide LHS shape or vice versa"); + + int64_t newShape = + std::min(std::max(int64_t(1), currShape / restStride), restShape); + + if (newShape != 1) { + modeShapes.push_back(makeLeaf(ctx, newShape)); + modeStrides.push_back(makeLeaf(ctx, restStride * currStride)); + } + + restShape = restShape / newShape; + restStride = llvm::divideCeilSigned(restStride, currShape); + } + + if (restShape != 1 || modeShapes.empty()) { + modeShapes.push_back(makeLeaf(ctx, restShape)); + modeStrides.push_back(makeLeaf(ctx, restStride * lhsStrides[0])); + } + + std::reverse(modeShapes.begin(), modeShapes.end()); + std::reverse(modeStrides.begin(), modeStrides.end()); + + if (modeShapes.size() == 1) { + resultShapes.push_back(modeShapes[0]); + resultStrides.push_back(modeStrides[0]); + } else { + resultShapes.push_back(makeTuple(ctx, modeShapes)); + resultStrides.push_back(makeTuple(ctx, modeStrides)); + } + } + + return PackMapAttr::get(ctx, makeTuple(ctx, resultShapes), + makeTuple(ctx, resultStrides)); +} + +/// Given a range [0, cotarget) and layout A, find layout B such that +/// logicalProduct(A, B) bijects onto [0, cotarget) -- i.e. A and B together +/// partition the entire range with no gaps and no overlaps. +/// +/// The idea: sort A's modes by stride (finest first) and scan left-to-right, +/// tracking `accumulated` = the size of the contiguous index block covered so +/// far. When the next mode has stride d > accumulated, there is a gap +/// [accumulated, d) that A never hits. We fill it with a complement mode +/// (d/accumulated, accumulated) -- stride accumulated steps over the already- +/// covered block, and shape d/accumulated repeats it enough times to reach d. +/// After each mode (s, d), the covered frontier advances to d*s. Finally, a +/// trailing mode covers from the last frontier up to cotarget. +/// +/// Example: A = (4):(1), cotarget = 16. +/// A covers {0, 1, 2, 3}. accumulated = 4. Trailing: 16/4 = 4 -> emit (4, 4). +/// Result: (4):(4), which covers {0, 4, 8, 12}. +/// Together A and B partition {0, ..., 15}. +PackMapAttr PackMapAttr::complement(int64_t cotarget) { + MLIRContext *ctx = getContext(); + + auto [filtShape, filtStride] = filterZeros(ctx, getShape(), getStride()); + SmallVector shapes = getLeaves(filtShape); + SmallVector strides = getLeaves(filtStride); + + SmallVector> modes; + for (auto [s, d] : llvm::zip_equal(shapes, strides)) { + modes.push_back({s, d}); + } + llvm::sort(modes, llvm::less_second()); + + SmallVector compShape, compStride; + int64_t accumulated = 1; + for (auto [s, d] : modes) { + int64_t gap = d / accumulated; + if (gap > 1) { + compShape.push_back(makeLeaf(ctx, gap)); + compStride.push_back(makeLeaf(ctx, accumulated)); + } + accumulated = d * s; + } + + { + int64_t remaining = llvm::divideCeilSigned(cotarget, accumulated); + compShape.push_back(makeLeaf(ctx, remaining)); + compStride.push_back(makeLeaf(ctx, accumulated)); + } + + std::reverse(compShape.begin(), compShape.end()); + std::reverse(compStride.begin(), compStride.end()); + + auto result = PackMapAttr::get(ctx, makeTuple(ctx, compShape), + makeTuple(ctx, compStride)); + return result.coalesce(); +} + +/// Split a layout into two modes: inner (within a tile) and outer (which tile). +/// +/// The tiler defines the shape of one tile. The result is a rank-2 layout +/// where mode 0 is the inner view (position within a tile) and mode 1 is the +/// outer view (which tile). Evaluating with (inner_coord, outer_coord) gives +/// the same index as the original layout evaluated at the corresponding +/// flat position. +/// +/// Example 1: (32):(1) divided by tiler (4):(1) +/// 32 elements split into 8 tiles of 4. +/// Result: ((4), (8)) : ((1), (4)) +/// evaluate({2, 3}) = 2*1 + 3*4 = 14 (element 2 of tile 3). +/// +/// Example 2: (32):(2) divided by tiler (4):(1) +/// Stride-2 layout (hits even indices) split into tiles of 4. +/// Result: ((4), (8)) : ((2), (8)) +/// evaluate({1, 2}) = 1*2 + 2*8 = 18. +PackMapAttr PackMapAttr::logicalDivide(PackMapAttr tiler) { + MLIRContext *ctx = getContext(); + + PackMapAttr coal = this->coalesce(); + PackMapAttr tilerComp = tiler.complement(coal.getSize()); + + PackMapAttr innerResult = coal.compose(tiler); + PackMapAttr outerResult = coal.compose(tilerComp); + + SmallVector resShape = {innerResult.getShape(), + outerResult.getShape()}; + SmallVector resStride = {innerResult.getStride(), + outerResult.getStride()}; + return PackMapAttr::get(ctx, makeTuple(ctx, resShape), + makeTuple(ctx, resStride)); +} + +PackMapAttr PackMapAttr::permute(ArrayRef perm) { + MLIRContext *ctx = getContext(); + SmallVector newShape, newStride; + for (int64_t i : perm) { + newShape.push_back(getShapeMode(i)); + newStride.push_back(getStrideMode(i)); + } + return PackMapAttr::get(ctx, makeTuple(ctx, newShape), + makeTuple(ctx, newStride)); +} + +PackMapAttr PackMapAttr::project(ArrayRef droppedDims) { + MLIRContext *ctx = getContext(); + SmallVector newShape, newStride; + for (auto [i, dropped] : llvm::enumerate(droppedDims)) { + if (!dropped) { + newShape.push_back(getShapeMode(i)); + newStride.push_back(getStrideMode(i)); + } + } + return PackMapAttr::get(ctx, makeTuple(ctx, newShape), + makeTuple(ctx, newStride)); +} + +/// Replicate this layout across multiple copies parameterized by a tiler. +/// +/// The result is a rank-2 layout where mode 0 is the original layout (position +/// within one copy) and mode 1 selects which copy using the tiler's coordinate +/// system. The total index space covered is size * tiler.cosize. +/// +/// Example 1: (4):(1) product with tiler (8):(1) +/// 4 elements replicated 8 times -> 32 total. +/// Result: ((4), (8)) : ((1), (4)) +/// evaluate({2, 3}) = 2*1 + 3*4 = 14 (element 2 of copy 3). +/// +/// Example 2: (4):(1) product with tiler (4):(2) +/// 4 elements replicated with stride-2 tiler -> cosize = (4-1)*2+1 = 7. +/// Result: ((4), (4)) : ((1), (4)) +/// evaluate({1, 2}) = 1*1 + 2*4 = 9. +PackMapAttr PackMapAttr::logicalProduct(PackMapAttr tiler) { + MLIRContext *ctx = getContext(); + + int64_t target = getSize() * tiler.getCosize(); + PackMapAttr comp = this->complement(target); + PackMapAttr mode1 = comp.compose(tiler); + + SmallVector combinedShape = {getShape(), mode1.getShape()}; + SmallVector combinedStride = {getStride(), mode1.getStride()}; + return PackMapAttr::get(ctx, makeTuple(ctx, combinedShape), + makeTuple(ctx, combinedStride)); +} + +PackMapAttr PackMapAttr::filter() { + MLIRContext *ctx = getContext(); + auto [filtShape, filtStride] = filterZeros(ctx, getShape(), getStride()); + return PackMapAttr::get(ctx, filtShape, filtStride).coalesce(); +} + +/// Find R such that A(R(x)) = x for all x in [0, injective range). +/// +/// R has the same shape as A but with natural row-major strides. This means R +/// unpacks a flat index into digits, and A re-packs them in its own stride +/// order -- the two cancel out and yield the original index. +/// +/// If A's smallest stride is > 1 (its output range has a gap at 0), no right +/// inverse exists and the result is the trivial layout (1):(0). +/// +/// Example 1: A = (4, 8):(8, 1) (row-major is an identity) +/// R = (4, 8):(8, 1). A(R(x)) = x. +/// +/// Example 2: A = (4, 2):(1, 4) (column-major 4x2) +/// R = (2, 4):(1, 2). +/// R(5) = 3, A(3) = 5. +PackMapAttr PackMapAttr::rightInverse() { + MLIRContext *ctx = getContext(); + SmallVector shapes = getLeaves(getShape()); + SmallVector strides = getLeaves(getStride()); + + int n = static_cast(shapes.size()); + SmallVector rStrides(n); + { + int64_t acc = 1; + for (int i = n - 1; i >= 0; --i) { + rStrides[i] = acc; + acc *= shapes[i]; + } + } + + SmallVector> sorted; + for (int i = 0; i < n; ++i) { + sorted.push_back({strides[i], shapes[i], rStrides[i]}); + } + llvm::sort(sorted, llvm::less_first()); + + SmallVector resShapes, resStrides; + int64_t currentIdx = 1; + for (auto [stride, shape, rStride] : sorted) { + if (shape == 1) { + continue; + } + if (currentIdx != stride) { + break; + } + resShapes.push_back(makeLeaf(ctx, shape)); + resStrides.push_back(makeLeaf(ctx, rStride)); + currentIdx = shape * stride; + } + + if (resShapes.empty()) { + resShapes.push_back(makeLeaf(ctx, 1)); + resStrides.push_back(makeLeaf(ctx, 0)); + } + + std::reverse(resShapes.begin(), resShapes.end()); + std::reverse(resStrides.begin(), resStrides.end()); + + return PackMapAttr::get(ctx, makeTuple(ctx, resShapes), + makeTuple(ctx, resStrides)) + .coalesce(); +} + +/// Find L such that L(A(x)) = x for all x in [0, size). +/// +/// Unlike rightInverse (which requires A's output range to be contiguous), +/// leftInverse works even when A's outputs have gaps. It uses complement to +/// fill the gaps, building a combined layout that covers [0, size) fully, +/// then takes the rightInverse of that. On A's actual outputs, this recovers x. +/// +/// Example 1: A = (4):(2) (maps {0,1,2,3} -> {0,2,4,6}) +/// L = (4, 2):(1, 4). +/// A(3) = 6. L(6): unpack 6 in shape (4,2) -> (3,0) -> 3*1 + 0*4 = 3. +/// +/// Example 2: A = (8):(2) (maps {0..7} -> {0,2,4,6,8,10,12,14}) +/// L = (8, 2):(1, 8). +/// A(5) = 10. L(10): unpack 10 in shape (8,2) -> (5,0) -> 5*1 + 0*8 = 5. +PackMapAttr PackMapAttr::leftInverse() { + MLIRContext *ctx = getContext(); + PackMapAttr comp = this->complement(getSize()); + + SmallVector combinedShape = {comp.getShape(), getShape()}; + SmallVector combinedStride = {comp.getStride(), getStride()}; + PackMapAttr combined = PackMapAttr::get(ctx, makeTuple(ctx, combinedShape), + makeTuple(ctx, combinedStride)); + return combined.rightInverse(); +} + +//===----------------------------------------------------------------------===// +// PackMapAttr - tiled divide and product +//===----------------------------------------------------------------------===// + +/// Take a rank-2 result and flatten mode 1 into top-level modes. +static PackMapAttr flattenRestModes(PackMapAttr divided) { + assert(divided.getRank() == 2 && "expected rank-2 layout"); + MLIRContext *ctx = divided.getContext(); + SmallVector newShape = {divided.getShapeMode(0)}; + SmallVector newStride = {divided.getStrideMode(0)}; + Attribute restShape = divided.getShapeMode(1); + Attribute restStride = divided.getStrideMode(1); + if (isLeaf(restShape)) { + newShape.push_back(restShape); + newStride.push_back(restStride); + } else { + for (auto [s, d] : llvm::zip_equal(cast(restShape), + cast(restStride))) { + newShape.push_back(s); + newStride.push_back(d); + } + } + return PackMapAttr::get(ctx, makeTuple(ctx, newShape), + makeTuple(ctx, newStride)); +} + +PackMapAttr PackMapAttr::tiledDivide(PackMapAttr tiler) { + return flattenRestModes(logicalDivide(tiler)); +} + +PackMapAttr PackMapAttr::tiledProduct(PackMapAttr tiler) { + return flattenRestModes(logicalProduct(tiler)); +} + +//===----------------------------------------------------------------------===// +// PackMapAttr - factory +//===----------------------------------------------------------------------===// + +/// Create the row-major identity layout: strides are suffix products. +/// shape (M, N, K) -> (M, N, K) : (N*K, K, 1). +PackMapAttr PackMapAttr::makeIdentity(MLIRContext *ctx, + ArrayRef shape) { + SmallVector leaves; + for (int64_t s : shape) { + leaves.push_back(makeLeaf(ctx, s)); + } + Attribute shapeAttr = makeTuple(ctx, leaves); + return PackMapAttr::get(ctx, shapeAttr, suffixProduct(ctx, shapeAttr)); +} + +//===----------------------------------------------------------------------===// +// Dialect attribute registration +//===----------------------------------------------------------------------===// + +void IREEMapDialect::registerAttributes() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "iree/compiler/Codegen/Dialect/Map/IR/IREEMapAttrs.cpp.inc" + >(); +} + +//===----------------------------------------------------------------------===// +// TableGen generated definitions +//===----------------------------------------------------------------------===// + +#define GET_ATTRDEF_CLASSES +#include "iree/compiler/Codegen/Dialect/Map/IR/IREEMapAttrs.cpp.inc" diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/IREEMapAttrs.h b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/IREEMapAttrs.h new file mode 100644 index 000000000000..036aa1dc1ff3 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/IREEMapAttrs.h @@ -0,0 +1,18 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_CODEGEN_DIALECT_MAP_IR_IREEMAPATTRS_H_ +#define IREE_COMPILER_CODEGEN_DIALECT_MAP_IR_IREEMAPATTRS_H_ + +#include "iree/compiler/Codegen/Dialect/Map/IR/IREEMapDialect.h" +#include "mlir/IR/BuiltinAttributes.h" + +// clang-format off: must be included after all LLVM/MLIR headers. +#define GET_ATTRDEF_CLASSES +#include "iree/compiler/Codegen/Dialect/Map/IR/IREEMapAttrs.h.inc" +// clang-format on + +#endif // IREE_COMPILER_CODEGEN_DIALECT_MAP_IR_IREEMAPATTRS_H_ diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/IREEMapAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/IREEMapAttrs.td new file mode 100644 index 000000000000..e4bf08441405 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/IREEMapAttrs.td @@ -0,0 +1,196 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_CODEGEN_DIALECT_MAP_IREEMAP_ATTRS +#define IREE_CODEGEN_DIALECT_MAP_IREEMAP_ATTRS + +include "iree/compiler/Codegen/Dialect/Map/IR/IREEMapBase.td" + +//===----------------------------------------------------------------------===// +// PackMapAttr +//===----------------------------------------------------------------------===// + +def IREEMap_PackMapAttr : IREEMap_Attr<"PackMap"> { + let mnemonic = "pack_map"; + + let summary = [{A layout map: a (Shape, Stride) pair mapping coordinates to indices.}]; + + let description = [{ + A layout map is a function from coordinates to 1-D indices, defined by a + `(Shape, Stride)` pair of hierarchical IntTuples. + + #### IntTuple + + An IntTuple is a recursive type: either a single integer or a tuple of + IntTuples. For example: `4`, `(2, 3)`, `((2, 4), 8)`. + + All values are static i64 constants. + + #### Coordinate-to-index mapping + + A layout maps a coordinate to a 1-D index by pairing each shape leaf with + its corresponding stride leaf and computing a weighted sum. The coordinate + is decomposed into per-leaf natural coordinates (like mixed-radix digits), + then each is multiplied by its stride and summed. + + For a flat layout `(S0, S1) : (D0, D1)`, coordinate `(c0, c1)` maps to: + + ``` + index = c0 * D0 + c1 * D1 + ``` + + **Example: row-major `(4, 8) : (8, 1)`** + + ``` + coord (0, 0) → 0*8 + 0*1 = 0 + coord (0, 5) → 0*8 + 5*1 = 5 + coord (2, 3) → 2*8 + 3*1 = 19 + coord (3, 7) → 3*8 + 7*1 = 31 + + col 0 col 1 col 2 ... col 7 + row 0: 0 1 2 ... 7 + row 1: 8 9 10 ... 15 + row 2: 16 17 18 ... 23 + row 3: 24 25 26 ... 31 + ``` + + **Example: column-major `(4, 8) : (1, 4)`** + + ``` + coord (0, 0) → 0*1 + 0*4 = 0 + coord (2, 0) → 2*1 + 0*4 = 2 + coord (0, 1) → 0*1 + 1*4 = 4 + coord (3, 7) → 3*1 + 7*4 = 31 + ``` + + **Broadcast (stride 0)** + + A stride of `0` means the coordinate for that dimension is ignored — the + dimension is broadcast. Any value along that axis maps to the same index. + + **Example: `(4, 8) : (0, 1)`** — dimension 0 is broadcast + + ``` + coord (0, 3) → 0*0 + 3*1 = 3 + coord (2, 3) → 2*0 + 3*1 = 3 // same index regardless of row + ``` + + **Example: hierarchical `((2, 4), 8) : ((16, 1), 4)`** + + Hierarchical shapes nest sub-modes within a mode. Coordinate `(c0, c1)` + first decomposes `c0` into sub-coordinates for the inner tuple `(2, 4)`, + then computes the weighted sum across all leaves: + + ``` + coord (5, 3) → c0=5 decomposes as (1, 1) in shape (2, 4) + → 1*16 + 1*1 + 3*4 = 29 + ``` + + #### Ordering convention + + This implementation uses **lexicographic (row-major)** ordering: the + rightmost mode is the innermost (fastest-varying) dimension. The default + stride for shape `(M, N, K)` is the suffix product `(N*K, K, 1)`. + + #### Congruence + + Shape and stride must be **congruent**: they have identical hierarchical + tree structure. + + #### Assembly format + + ```mlir + #iree_map.pack_map<(8) : (1)> // rank-1 contiguous + #iree_map.pack_map<(4, 8) : (8, 1)> // rank-2 row-major + #iree_map.pack_map<(4, 8) : (1, 4)> // rank-2 column-major + #iree_map.pack_map<(4, 8) : (0, 1)> // dim 0 broadcast + #iree_map.pack_map<((2, 4), 8) : ((16, 1), 4)> // hierarchical + ``` + }]; + + let parameters = (ins + "::mlir::Attribute":$shape, + "::mlir::Attribute":$stride + ); + + let assemblyFormat = "`<` custom($shape) `:` custom($stride) `>`"; + let genVerifyDecl = 1; + + let extraClassDeclaration = [{ + // --- Properties --- + /// Number of modes (top-level tuples) in the shape/stride. + int64_t getRank(); + /// Maximum nesting depth of the shape/stride tuple. A flat layout has depth 1. + int64_t getDepth(); + /// Product of all leaf shape values. May overflow for large layouts. + int64_t getSize(); + /// Maximum index + 1: cosize = sum((s-1)*d for each leaf) + 1. + int64_t getCosize(); + + // --- Mode access --- + // Get the shape tuple for mode i. + ::mlir::Attribute getShapeMode(int64_t i); + // Get the stride tuple for mode i. + ::mlir::Attribute getStrideMode(int64_t i); + + // --- Evaluation --- + + /// Evaluate the layout function: coord -> index. + /// + /// Accepts either a single flat index (converted via idx2crd) or a full + /// per-leaf coordinate vector. Both are reduced to a weighted sum with the + /// stride leaves. + int64_t evaluate(::llvm::ArrayRef coord); + + // --- Simplification --- + + /// Merge adjacent leaves that form a contiguous range. Does not change the + /// layout function, only simplifies the representation. + PackMapAttr coalesce(); + /// Coalesce within each top-level mode independently, preserving the + /// mode boundary. Unlike coalesce(), leaves are not merged across modes. + PackMapAttr coalesceModes(); + /// Flatten all hierarchy into a single-level tuple of leaves. + PackMapAttr flatten(); + + // --- Algebra --- + + /// Functional composition: result(c) = this(rhs(c)). + PackMapAttr compose(PackMapAttr rhs); + /// Find a layout B such that A and B together tile [0, cotarget). + PackMapAttr complement(int64_t cotarget); + /// Factor layout into (inner, outer) using a tiler: inner covers the tile, + /// outer covers the rest. + PackMapAttr logicalDivide(PackMapAttr tiler); + /// Replicate this layout using a tiler pattern. + PackMapAttr logicalProduct(PackMapAttr tiler); + /// Remove trivial modes (size-1 or stride-0 leaves), then coalesce. + /// Returns (1):(0) if all modes are trivial. + PackMapAttr filter(); + /// Find R such that A(R(x)) = x for all x in the injective range. + PackMapAttr rightInverse(); + /// Find L such that L(A(x)) = x for all x in [0, size). + PackMapAttr leftInverse(); + /// Like logicalDivide, but flattens the outer modes to top level. + PackMapAttr tiledDivide(PackMapAttr tiler); + /// Like logicalProduct, but flattens the replicated modes to top level. + PackMapAttr tiledProduct(PackMapAttr tiler); + + // --- Mode operations --- + /// Reorder modes: result mode i = original mode perm[i]. + PackMapAttr permute(::llvm::ArrayRef perm); + /// Drop modes where droppedDims[i] is true. + PackMapAttr project(::llvm::ArrayRef droppedDims); + + // --- Factory --- + /// Create the row-major (lexicographic) identity layout for a flat shape. + /// For shape (M, N, K) produces (M, N, K) : (N*K, K, 1). + static PackMapAttr makeIdentity(::mlir::MLIRContext *ctx, + ::llvm::ArrayRef shape); + }]; +} + +#endif // IREE_CODEGEN_DIALECT_MAP_IREEMAP_ATTRS diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/IREEMapBase.td b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/IREEMapBase.td new file mode 100644 index 000000000000..4ec7df8349e2 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/IREEMapBase.td @@ -0,0 +1,45 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_CODEGEN_DIALECT_MAP_IREEMAP_BASE +#define IREE_CODEGEN_DIALECT_MAP_IREEMAP_BASE + +include "mlir/IR/OpBase.td" +include "mlir/IR/AttrTypeBase.td" + +//===----------------------------------------------------------------------===// +// IREE Map Dialect +//===----------------------------------------------------------------------===// + +def IREEMap_Dialect : Dialect { + let name = "iree_map"; + let cppNamespace = "::mlir::iree_compiler::IREE::Map"; + + let summary = [{ + A dialect for map algebra used in code generation. + }]; + let description = [{ + This dialect provides map attributes for describing hierarchical coordinate + mappings used in codegen. The base of the dialect is the PackMap attribute + which defines a coordinate to index mapping and defines a closed algebra + over it. This algebra allows us to represent complex coordinate + transformations as algebra operations. + }]; + + let useDefaultAttributePrinterParser = 1; + let extraClassDeclaration = [{ + void registerAttributes(); + }]; +} + +//===----------------------------------------------------------------------===// +// Attribute base class +//===----------------------------------------------------------------------===// + +class IREEMap_Attr traits = []> + : AttrDef; + +#endif // IREE_CODEGEN_DIALECT_MAP_IREEMAP_BASE diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/IREEMapDialect.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/IREEMapDialect.cpp new file mode 100644 index 000000000000..0698d9e54bb1 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/IREEMapDialect.cpp @@ -0,0 +1,32 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Dialect/Map/IR/IREEMapDialect.h" + +#include "iree/compiler/Codegen/Dialect/Map/IR/IREEMapAttrs.h" +#include "mlir/IR/DialectImplementation.h" + +#include "iree/compiler/Codegen/Dialect/Map/IR/IREEMapDialect.cpp.inc" + +namespace mlir::iree_compiler::IREE::Map { + +struct IREEMapDialectOpAsmInterface final : OpAsmDialectInterface { + using OpAsmDialectInterface::OpAsmDialectInterface; + AliasResult getAlias(Attribute attr, raw_ostream &os) const override { + if (isa(attr)) { + os << "pack_map"; + return AliasResult::OverridableAlias; + } + return AliasResult::NoAlias; + } +}; + +void IREEMapDialect::initialize() { + addInterfaces(); + registerAttributes(); +} + +} // namespace mlir::iree_compiler::IREE::Map diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/IREEMapDialect.h b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/IREEMapDialect.h new file mode 100644 index 000000000000..a1b2ddff8d1f --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/IREEMapDialect.h @@ -0,0 +1,16 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_CODEGEN_DIALECT_MAP_IR_IREEMAPDIALECT_H_ +#define IREE_COMPILER_CODEGEN_DIALECT_MAP_IR_IREEMAPDIALECT_H_ + +#include "mlir/IR/Dialect.h" + +// clang-format off: must be included after all LLVM/MLIR headers. +#include "iree/compiler/Codegen/Dialect/Map/IR/IREEMapDialect.h.inc" +// clang-format on + +#endif // IREE_COMPILER_CODEGEN_DIALECT_MAP_IR_IREEMAPDIALECT_H_ diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/IntTuple.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/IntTuple.cpp new file mode 100644 index 000000000000..6ea3e1caea28 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/IntTuple.cpp @@ -0,0 +1,301 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Dialect/Map/IR/IntTuple.h" + +#include "llvm/ADT/STLExtras.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" + +namespace mlir::iree_compiler::IREE::Map { + +// --- Query functions --- + +bool isIntTuple(Attribute attr) { + if (!attr) { + return false; + } + if (auto intAttr = dyn_cast(attr)) { + return intAttr.getType().isInteger(64); + } + if (auto arrAttr = dyn_cast(attr)) { + return llvm::all_of(arrAttr, isIntTuple); + } + return false; +} + +bool isLeaf(Attribute attr) { return isa(attr); } + +int64_t getLeafValue(Attribute attr) { + return cast(attr).getInt(); +} + +int64_t getRank(Attribute attr) { + if (isLeaf(attr)) { + return 1; + } + return cast(attr).size(); +} + +int64_t getDepth(Attribute attr) { + if (isLeaf(attr)) { + return 0; + } + int64_t maxChild = 0; + for (Attribute child : cast(attr)) { + maxChild = std::max(maxChild, getDepth(child)); + } + return 1 + maxChild; +} + +int64_t getSize(Attribute attr) { + if (isLeaf(attr)) { + return getLeafValue(attr); + } + int64_t product = 1; + for (Attribute child : cast(attr)) { + product *= getSize(child); + } + return product; +} + +Attribute getElement(Attribute attr, int64_t i) { + if (isLeaf(attr)) { + assert(i == 0 && "leaf only has element 0"); + return attr; + } + return cast(attr)[i]; +} + +// --- Predicates --- + +bool isCongruent(Attribute a, Attribute b) { + if (isLeaf(a) && isLeaf(b)) { + return true; + } + if (isLeaf(a) != isLeaf(b)) { + return false; + } + auto arrA = cast(a); + auto arrB = cast(b); + // all_of_zip returns false if arrays have different sizes. + return llvm::all_of_zip(arrA, arrB, [](Attribute ea, Attribute eb) { + return isCongruent(ea, eb); + }); +} + +// --- Builders --- + +Attribute makeLeaf(MLIRContext *ctx, int64_t val) { + return IntegerAttr::get(IntegerType::get(ctx, 64), val); +} + +Attribute makeTuple(MLIRContext *ctx, ArrayRef elements) { + return ArrayAttr::get(ctx, elements); +} + +Attribute simplify(Attribute attr) { + if (isLeaf(attr)) { + return attr; + } + auto arr = cast(attr); + if (arr.size() == 1) { + return simplify(arr[0]); + } + SmallVector result; + for (Attribute elem : arr) { + result.push_back(simplify(elem)); + } + return makeTuple(attr.getContext(), result); +} + +Attribute flatten(MLIRContext *ctx, Attribute tuple) { + SmallVector leafAttrs = llvm::map_to_vector( + getLeaves(tuple), [&](int64_t v) { return makeLeaf(ctx, v); }); + return makeTuple(ctx, leafAttrs); +} + +// --- Arithmetic --- + +int64_t innerProduct(Attribute coord, Attribute stride) { + if (isLeaf(coord)) { + assert(isLeaf(stride)); + return getLeafValue(coord) * getLeafValue(stride); + } + auto coordArr = cast(coord); + auto strideArr = cast(stride); + int64_t sum = 0; + for (auto [c, s] : llvm::zip_equal(coordArr, strideArr)) { + sum += innerProduct(c, s); + } + return sum; +} + +/// Divide a shape by a divisor, distributing the division left-to-right +/// across the hierarchical structure (outermost mode first). +/// +/// For a leaf: if leaf is divisible by divisor, return leaf / divisor. +/// Otherwise the divisor is larger than the leaf, so the leaf is fully +/// consumed (returns 1) and the remaining divisor carries forward. +/// +/// For a tuple: fold left-to-right with a running remainder. Each child +/// either absorbs the remainder fully, partially, or is untouched once +/// the remainder is exhausted. Left-to-right is correct because in +/// lexicographic ordering, leftmost modes are outermost -- dividing +/// removes the outermost coordinates first. +Attribute shapeDiv(MLIRContext *ctx, Attribute shape, int64_t divisor) { + if (divisor == 1) { + return shape; + } + + if (isLeaf(shape)) { + int64_t s = getLeafValue(shape); + if (s % divisor == 0) { + return makeLeaf(ctx, s / divisor); + } + assert(divisor % s == 0 && "shapeDiv: divisibility constraint violated"); + return makeLeaf(ctx, 1); + } + + // Tuple: fold left-to-right with running remainder. + auto arr = cast(shape); + SmallVector result; + int64_t rem = divisor; + for (Attribute child : arr) { + int64_t childSize = getSize(child); + if (rem == 1) { + result.push_back(child); + } else if (childSize % rem == 0) { + result.push_back(shapeDiv(ctx, child, rem)); + rem = 1; + } else { + assert(rem % childSize == 0 && + "shapeDiv: divisibility constraint violated"); + result.push_back(shapeDiv(ctx, child, childSize)); + rem /= childSize; + } + } + return makeTuple(ctx, result); +} + +Attribute suffixProduct(MLIRContext *ctx, Attribute shape) { + SmallVector leaves = getLeaves(shape); + int n = static_cast(leaves.size()); + SmallVector result(n); + int64_t acc = 1; + for (int i = n - 1; i >= 0; --i) { + result[i] = acc; + acc *= leaves[i]; + } + SmallVector attrs = + llvm::map_to_vector(result, [&](int64_t v) { return makeLeaf(ctx, v); }); + return makeTuple(ctx, attrs); +} + +// --- Coordinate conversion --- + +// Lexicographic (row-major): rightmost dimension varies fastest. +SmallVector idx2crd(int64_t idx, Attribute shape) { + SmallVector leaves = getLeaves(shape); + int n = static_cast(leaves.size()); + SmallVector coords(n); + int64_t remaining = idx; + for (int i = n - 1; i >= 0; --i) { + coords[i] = remaining % leaves[i]; + remaining /= leaves[i]; + } + return coords; +} + +int64_t crd2idx(ArrayRef coord, Attribute stride) { + SmallVector strides = getLeaves(stride); + int64_t result = 0; + for (auto [c, s] : llvm::zip_equal(coord, strides)) { + result += c * s; + } + return result; +} + +// --- Filtering --- + +std::pair filterZeros(MLIRContext *ctx, Attribute shape, + Attribute stride) { + SmallVector flatShape = getLeaves(shape); + SmallVector flatStride = getLeaves(stride); + assert(flatShape.size() == flatStride.size()); + + SmallVector filteredShape, filteredStride; + for (auto [s, d] : llvm::zip_equal(flatShape, flatStride)) { + if (d != 0 && s != 1) { + filteredShape.push_back(makeLeaf(ctx, s)); + filteredStride.push_back(makeLeaf(ctx, d)); + } + } + + if (filteredShape.empty()) { + filteredShape.push_back(makeLeaf(ctx, 1)); + filteredStride.push_back(makeLeaf(ctx, 0)); + } + + return {makeTuple(ctx, filteredShape), makeTuple(ctx, filteredStride)}; +} + +SmallVector getLeaves(Attribute attr) { + SmallVector result; + if (isLeaf(attr)) { + result.push_back(getLeafValue(attr)); + return result; + } + for (Attribute child : cast(attr)) { + SmallVector childLeaves = getLeaves(child); + llvm::append_range(result, childLeaves); + } + return result; +} + +// --- Leaf info --- + +SmallVector getLeafInfos(Attribute shape, Attribute stride) { + SmallVector shapes = getLeaves(shape); + SmallVector strides = getLeaves(stride); + assert(shapes.size() == strides.size()); + int n = static_cast(shapes.size()); + + SmallVector dataStrides(n, 1); + for (int i = n - 2; i >= 0; --i) { + dataStrides[i] = dataStrides[i + 1] * shapes[i + 1]; + } + + SmallVector result; + for (int i = 0; i < n; ++i) { + result.push_back({shapes[i], strides[i], dataStrides[i]}); + } + return result; +} + +SmallVector +filterLeafInfos(Attribute shape, Attribute stride, + llvm::function_ref pred) { + SmallVector result; + for (const LeafInfo &leaf : getLeafInfos(shape, stride)) { + if (pred(leaf)) { + result.push_back(leaf); + } + } + return result; +} + +int64_t +foldLeafInfos(Attribute shape, Attribute stride, int64_t init, + llvm::function_ref fn) { + int64_t acc = init; + for (const LeafInfo &leaf : getLeafInfos(shape, stride)) { + acc = fn(acc, leaf); + } + return acc; +} + +} // namespace mlir::iree_compiler::IREE::Map diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/IntTuple.h b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/IntTuple.h new file mode 100644 index 000000000000..757cef23319c --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/IntTuple.h @@ -0,0 +1,125 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// IntTuple utilities. +// +// An IntTuple is a recursive type: either a single integer (IntegerAttr, i64) +// or a tuple of IntTuples (ArrayAttr). These utilities operate on Attribute +// trees encoding this structure. + +#ifndef IREE_COMPILER_CODEGEN_DIALECT_MAP_IR_INTTUPLE_H_ +#define IREE_COMPILER_CODEGEN_DIALECT_MAP_IR_INTTUPLE_H_ + +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/MLIRContext.h" + +namespace mlir::iree_compiler::IREE::Map { + +// --- Query functions --- + +/// Returns true if attr is a valid IntTuple (IntegerAttr or ArrayAttr of +/// IntTuples, recursively). +bool isIntTuple(Attribute attr); + +/// Returns true if attr is a leaf integer (IntegerAttr). +bool isLeaf(Attribute attr); + +/// Returns the integer value of a leaf. Requires isLeaf(attr). +int64_t getLeafValue(Attribute attr); + +/// Top-level element count. 1 for a leaf, N for an N-element ArrayAttr. +int64_t getRank(Attribute attr); + +/// Maximum nesting depth. 0 for a leaf, 1 + max(child depths) for a tuple. +int64_t getDepth(Attribute attr); + +/// Product of all leaf integers. May overflow for large layouts. +int64_t getSize(Attribute attr); + +/// Get the i-th top-level element. For a leaf, i must be 0 (returns self). +Attribute getElement(Attribute attr, int64_t i); + +/// Collect all leaves from an IntTuple into a flat vector. +SmallVector getLeaves(Attribute attr); + +// --- Predicates --- + +/// Returns true if a and b have identical tree structure (same nesting, same +/// ranks at each level). Leaf values may differ. +bool isCongruent(Attribute a, Attribute b); + +// --- Builders --- + +/// Create a leaf IntegerAttr(i64). +Attribute makeLeaf(MLIRContext *ctx, int64_t val); + +/// Create a tuple (ArrayAttr) from elements. +Attribute makeTuple(MLIRContext *ctx, ArrayRef elements); + +/// Flatten all nesting into a single-level tuple of leaves. +Attribute flatten(MLIRContext *ctx, Attribute tuple); + +/// Recursively unwrap single-element tuples: (x) → simplify(x). +Attribute simplify(Attribute attr); + +// --- Arithmetic --- + +/// Recursive inner product: sum of (leaf_coord * leaf_stride) over all leaves. +int64_t innerProduct(Attribute coord, Attribute stride); + +/// shape_div: divide shape by divisor, distributing left-to-right. +Attribute shapeDiv(MLIRContext *ctx, Attribute shape, int64_t divisor); + +/// suffixProduct: compute row-major (lexicographic) strides for a flat shape. +/// For shape (M, N, K) returns (N*K, K, 1) as an ArrayAttr of i64 leaves. +Attribute suffixProduct(MLIRContext *ctx, Attribute shape); + +// --- Coordinate conversion --- + +/// idx2crd: convert a 1-D index to a natural coordinate matching shape. +/// Uses lexicographic (row-major) ordering. +SmallVector idx2crd(int64_t idx, Attribute shape); + +/// crd2idx: convert a coordinate to a 1-D index via inner product with stride. +int64_t crd2idx(ArrayRef coord, Attribute stride); + +// --- Filtering --- + +/// Filter stride-0 and size-1 modes from parallel shape+stride tuples. +/// Returns (filteredShape, filteredStride) as a pair of attributes. +std::pair filterZeros(MLIRContext *ctx, Attribute shape, + Attribute stride); + +// --- Leaf info --- + +/// Info about a single leaf in a (shape, stride) mode pair. +/// `stride` is the layout stride (0 = broadcast). +/// `dataStride` is the lex data stride (product of all subsequent leaf sizes). +struct LeafInfo { + int64_t size = {}; + int64_t stride = {}; + int64_t dataStride = {}; +}; + +/// Walk leaves of parallel (shape, stride), computing lex data strides. +SmallVector getLeafInfos(Attribute shape, Attribute stride); + +/// Filter leaf infos matching a predicate. +SmallVector +filterLeafInfos(Attribute shape, Attribute stride, + llvm::function_ref pred); + +/// Fold over leaf infos with an accumulator. +/// fn receives (accumulator, LeafInfo) and returns the new accumulator. +int64_t +foldLeafInfos(Attribute shape, Attribute stride, int64_t init, + llvm::function_ref fn); + +} // namespace mlir::iree_compiler::IREE::Map + +#endif // IREE_COMPILER_CODEGEN_DIALECT_MAP_IR_INTTUPLE_H_ diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/test/BUILD.bazel new file mode 100644 index 000000000000..67448eded56b --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/test/BUILD.bazel @@ -0,0 +1,25 @@ +# Copyright 2026 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite") + +package( + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_lit_test_suite( + name = "lit", + srcs = [ + "invalid.mlir", + "roundtrip.mlir", + ], + cfg = "//compiler:lit.cfg.py", + tools = [ + "//tools:iree-opt", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/test/CMakeLists.txt new file mode 100644 index 000000000000..66e790f1679d --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/test/CMakeLists.txt @@ -0,0 +1,24 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Codegen/Dialect/Map/IR/test/BUILD.bazel # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_lit_test_suite( + NAME + lit + SRCS + "invalid.mlir" + "roundtrip.mlir" + TOOLS + FileCheck + iree-opt +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/test/invalid.mlir b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/test/invalid.mlir new file mode 100644 index 000000000000..ca44a6be8347 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/test/invalid.mlir @@ -0,0 +1,34 @@ +// RUN: iree-opt --split-input-file --verify-diagnostics %s + +// expected-error @+1 {{shape leaf values must be positive, got 0}} +func.func @zero_shape() attributes {layout = #iree_map.pack_map<(0, 8) : (1, 4)>} { + return +} + +// ----- + +// expected-error @+1 {{shape leaf values must be positive, got -1}} +func.func @negative_shape() attributes {layout = #iree_map.pack_map<(-1, 8) : (1, 4)>} { + return +} + +// ----- + +// expected-error @+1 {{stride leaf values must be non-negative, got -1}} +func.func @negative_stride() attributes {layout = #iree_map.pack_map<(4, 8) : (-1, 4)>} { + return +} + +// ----- + +// expected-error @+1 {{shape and stride must be congruent}} +func.func @non_congruent_stride_nested() attributes {layout = #iree_map.pack_map<(4, 8) : (1, (2, 3))>} { + return +} + +// ----- + +// expected-error @+1 {{shape and stride must be congruent}} +func.func @non_congruent_shape_nested() attributes {layout = #iree_map.pack_map<((2, 4), 8) : (1, 4)>} { + return +} diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/test/roundtrip.mlir b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/test/roundtrip.mlir new file mode 100644 index 000000000000..5db57cec371e --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/test/roundtrip.mlir @@ -0,0 +1,91 @@ +// RUN: iree-opt --split-input-file %s | iree-opt --split-input-file | FileCheck %s + +// ============================================================================ +// pack_map roundtrips (no coalescing — preserves structure exactly) +// ============================================================================ + +// CHECK: #[[$LAYOUT:.+]] = #iree_map.pack_map<(8) : (1)> +// CHECK-LABEL: func @pack_map_rank1_contiguous +// CHECK-SAME: layout = #[[$LAYOUT]] +func.func @pack_map_rank1_contiguous() attributes {layout = #iree_map.pack_map<(8) : (1)>} { + return +} + +// ----- + +// CHECK: #[[$LAYOUT:.+]] = #iree_map.pack_map<(8) : (2)> +// CHECK-LABEL: func @pack_map_rank1_strided +// CHECK-SAME: layout = #[[$LAYOUT]] +func.func @pack_map_rank1_strided() attributes {layout = #iree_map.pack_map<(8) : (2)>} { + return +} + +// ----- + +// Rank-2 column-major. +// CHECK: #[[$LAYOUT:.+]] = #iree_map.pack_map<(4, 8) : (1, 4)> +// CHECK-LABEL: func @pack_map_rank2_col_major +// CHECK-SAME: layout = #[[$LAYOUT]] +func.func @pack_map_rank2_col_major() attributes {layout = #iree_map.pack_map<(4, 8) : (1, 4)>} { + return +} + +// ----- + +// Rank-2 row-major. +// CHECK: #[[$LAYOUT:.+]] = #iree_map.pack_map<(4, 8) : (8, 1)> +// CHECK-LABEL: func @pack_map_rank2_row_major +// CHECK-SAME: layout = #[[$LAYOUT]] +func.func @pack_map_rank2_row_major() attributes {layout = #iree_map.pack_map<(4, 8) : (8, 1)>} { + return +} + +// ----- + +// Rank-3. +// CHECK: #[[$LAYOUT:.+]] = #iree_map.pack_map<(2, 4, 8) : (1, 2, 8)> +// CHECK-LABEL: func @pack_map_rank3 +// CHECK-SAME: layout = #[[$LAYOUT]] +func.func @pack_map_rank3() attributes {layout = #iree_map.pack_map<(2, 4, 8) : (1, 2, 8)>} { + return +} + +// ----- + +// Hierarchical shape and stride (MMA-style). +// CHECK: #[[$LAYOUT:.+]] = #iree_map.pack_map<((2, 4), 8) : ((1, 16), 4)> +// CHECK-LABEL: func @pack_map_hierarchical +// CHECK-SAME: layout = #[[$LAYOUT]] +func.func @pack_map_hierarchical() attributes {layout = #iree_map.pack_map<((2, 4), 8) : ((1, 16), 4)>} { + return +} + +// ----- + +// Deeply nested hierarchy — pack_map preserves nesting exactly. +// CHECK: #[[$LAYOUT:.+]] = #iree_map.pack_map<(((2, 2), 4), 8) : (((1, 4), 16), 2)> +// CHECK-LABEL: func @pack_map_deeply_nested +// CHECK-SAME: layout = #[[$LAYOUT]] +func.func @pack_map_deeply_nested() attributes {layout = #iree_map.pack_map<(((2, 2), 4), 8) : (((1, 4), 16), 2)>} { + return +} + +// ----- + +// Stride-0 (broadcast mode). +// CHECK: #[[$LAYOUT:.+]] = #iree_map.pack_map<(4, 8) : (0, 1)> +// CHECK-LABEL: func @pack_map_broadcast +// CHECK-SAME: layout = #[[$LAYOUT]] +func.func @pack_map_broadcast() attributes {layout = #iree_map.pack_map<(4, 8) : (0, 1)>} { + return +} + +// ----- + +// Shape-1 mode (unit extent). +// CHECK: #[[$LAYOUT:.+]] = #iree_map.pack_map<(1, 8) : (0, 1)> +// CHECK-LABEL: func @pack_map_unit_extent +// CHECK-SAME: layout = #[[$LAYOUT]] +func.func @pack_map_unit_extent() attributes {layout = #iree_map.pack_map<(1, 8) : (0, 1)>} { + return +} diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/unittests/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/unittests/BUILD.bazel new file mode 100644 index 000000000000..ea32c512390f --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/unittests/BUILD.bazel @@ -0,0 +1,24 @@ +# Copyright 2026 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_test") + +package( + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_compiler_cc_test( + name = "PackMapTest", + testonly = True, + srcs = ["PackMapTest.cpp"], + deps = [ + "//compiler/src/iree/compiler/Codegen/Dialect/Map/IR:IREEMapDialect", + "//compiler/src/iree/testing:gtest_main", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", + ], +) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/unittests/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/unittests/CMakeLists.txt new file mode 100644 index 000000000000..05abab5a0c18 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/unittests/CMakeLists.txt @@ -0,0 +1,26 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Codegen/Dialect/Map/IR/unittests/BUILD.bazel # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_cc_test( + NAME + PackMapTest + SRCS + "PackMapTest.cpp" + DEPS + MLIRIR + gmock + gtest + iree::compiler::Codegen::Dialect::Map::IR::IREEMapDialect + iree::testing::gtest_main +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/unittests/PackMapTest.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/unittests/PackMapTest.cpp new file mode 100644 index 000000000000..c51625dd8ab0 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Map/IR/unittests/PackMapTest.cpp @@ -0,0 +1,528 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include "iree/compiler/Codegen/Dialect/Map/IR/IREEMapAttrs.h" +#include "iree/compiler/Codegen/Dialect/Map/IR/IREEMapDialect.h" +#include "iree/compiler/Codegen/Dialect/Map/IR/IntTuple.h" + +namespace mlir::iree_compiler::IREE::Map { +namespace { + +// Recursive helper for building nested IntTuple literals. +// Leaves are constructed from int64_t; tuples from initializer_list. +// Examples: +// T{4} -> leaf(4) +// T{4, 8} -> tuple(leaf(4), leaf(8)) +// T{{2, 4}, 8} -> tuple(tuple(leaf(2), leaf(4)), leaf(8)) +// T{{{2, 4}, 8}, 16} -> depth-3 nesting +struct T { + std::variant> val; + T(int64_t v) : val(v) {} + T(std::initializer_list children) : val(std::vector(children)) {} + + Attribute toAttr(MLIRContext *c) const { + if (auto *v = std::get_if(&val)) { + return makeLeaf(c, *v); + } + SmallVector attrs = + llvm::map_to_vector(std::get>(val), + [&](const T &child) { return child.toAttr(c); }); + return makeTuple(c, attrs); + } +}; + +class PackMapTest : public ::testing::Test { +protected: + PackMapTest() { + DialectRegistry reg; + reg.insert(); + ctx.appendDialectRegistry(reg); + ctx.loadAllAvailableDialects(); + } + + MLIRContext *getContext() { return &ctx; } + + // Create a PackMapAttr from nested shape/stride expressions. + PackMapAttr make(T shape, T stride) { + MLIRContext *c = getContext(); + return PackMapAttr::get(c, shape.toAttr(c), stride.toAttr(c)); + } + +private: + MLIRContext ctx; +}; + +//===----------------------------------------------------------------------===// +// Properties +//===----------------------------------------------------------------------===// + +TEST_F(PackMapTest, Rank) { + EXPECT_EQ(make({8}, {1}).getRank(), 1); + EXPECT_EQ(make({4, 8}, {8, 1}).getRank(), 2); + EXPECT_EQ(make({2, 3, 4}, {12, 4, 1}).getRank(), 3); +} + +TEST_F(PackMapTest, Depth) { + EXPECT_EQ(make({8}, {1}).getDepth(), 1); + EXPECT_EQ(make({4, 8}, {8, 1}).getDepth(), 1); + // Depth 2: ((2, 4), 8) : ((16, 1), 4) + EXPECT_EQ(make({{2, 4}, 8}, {{16, 1}, 4}).getDepth(), 2); + // Depth 3: (((2, 4), 8), 16) : (((32, 4), 2), 1) + EXPECT_EQ(make({{{2, 4}, 8}, 16}, {{{32, 4}, 2}, 1}).getDepth(), 3); +} + +TEST_F(PackMapTest, Size) { + EXPECT_EQ(make({8}, {1}).getSize(), 8); + EXPECT_EQ(make({4, 8}, {8, 1}).getSize(), 32); + // Nested: ((2, 3), 4) leaves are (2, 3, 4), size = 24. + EXPECT_EQ(make({{2, 3}, 4}, {{12, 1}, 4}).getSize(), 24); +} + +TEST_F(PackMapTest, Cosize) { + // (8) : (1) -> cosize = (8-1)*1 + 1 = 8 + EXPECT_EQ(make({8}, {1}).getCosize(), 8); + // (4, 8) : (8, 1) -> cosize = (4-1)*8 + (8-1)*1 + 1 = 32 + EXPECT_EQ(make({4, 8}, {8, 1}).getCosize(), 32); + // Nested ((4, 8)) : ((1, 4)) -> cosize = (4-1)*1 + (8-1)*4 + 1 = 32 + EXPECT_EQ(make({{4, 8}}, {{1, 4}}).getCosize(), 32); +} + +//===----------------------------------------------------------------------===// +// Mode access +//===----------------------------------------------------------------------===// + +TEST_F(PackMapTest, ModeAccess) { + auto layout = make({4, 8}, {8, 1}); + EXPECT_EQ(getLeafValue(layout.getShapeMode(0)), 4); + EXPECT_EQ(getLeafValue(layout.getShapeMode(1)), 8); + EXPECT_EQ(getLeafValue(layout.getStrideMode(0)), 8); + EXPECT_EQ(getLeafValue(layout.getStrideMode(1)), 1); +} + +//===----------------------------------------------------------------------===// +// Evaluation +//===----------------------------------------------------------------------===// + +TEST_F(PackMapTest, EvaluateFlatIndex) { + // (4, 8) : (8, 1), identity row-major + auto layout = make({4, 8}, {8, 1}); + for (int i = 0; i < 32; ++i) { + EXPECT_EQ(layout.evaluate({i}), i); + } +} + +TEST_F(PackMapTest, EvaluateMultiDimCoord) { + // Nested: ((2, 4), 8) : ((16, 1), 4) — leaves (2,4,8) : (16,1,4) + auto layout = make({{2, 4}, 8}, {{16, 1}, 4}); + EXPECT_EQ(layout.evaluate({0, 0, 0}), 0); + EXPECT_EQ(layout.evaluate({1, 0, 0}), 16); + EXPECT_EQ(layout.evaluate({0, 1, 0}), 1); + EXPECT_EQ(layout.evaluate({0, 0, 1}), 4); + EXPECT_EQ(layout.evaluate({1, 3, 7}), 47); // 1*16 + 3*1 + 7*4 +} + +TEST_F(PackMapTest, EvaluateColumnMajor) { + // (4, 8) : (1, 4) — column-major + auto layout = make({4, 8}, {1, 4}); + EXPECT_EQ(layout.evaluate({0, 0}), 0); + EXPECT_EQ(layout.evaluate({1, 0}), 1); + EXPECT_EQ(layout.evaluate({0, 1}), 4); + EXPECT_EQ(layout.evaluate({3, 7}), 31); +} + +TEST_F(PackMapTest, EvaluateFlatIndexNonIdentity) { + // Flat index on column-major (4, 8):(1, 4) exercises idx2crd + crd2idx. + // idx2crd(i, (4,8)) decomposes i row-major; crd2idx then applies col-major + // strides. + auto layout = make({4, 8}, {1, 4}); + EXPECT_EQ(layout.evaluate({0}), 0); // (0,0) -> 0 + EXPECT_EQ(layout.evaluate({1}), 4); // (0,1) -> 0*1+1*4 = 4 + EXPECT_EQ(layout.evaluate({5}), 20); // (0,5) -> 0*1+5*4 = 20 + EXPECT_EQ(layout.evaluate({8}), 1); // (1,0) -> 1*1+0*4 = 1 +} + +TEST_F(PackMapTest, EvaluateStrided) { + // (4):(3) — stride > 1, cosize (10) > size (4). + auto layout = make({4}, {3}); + EXPECT_EQ(layout.evaluate({0}), 0); + EXPECT_EQ(layout.evaluate({1}), 3); + EXPECT_EQ(layout.evaluate({3}), 9); + + // (4, 8):(2, 1) — outer stride is not shape*inner_stride. + auto layout2 = make({4, 8}, {2, 1}); + EXPECT_EQ(layout2.evaluate({0, 0}), 0); + EXPECT_EQ(layout2.evaluate({1, 0}), 2); + EXPECT_EQ(layout2.evaluate({0, 1}), 1); + EXPECT_EQ(layout2.evaluate({3, 7}), 13); // 3*2 + 7*1 +} + +TEST_F(PackMapTest, EvaluateBroadcast) { + // (4, 8):(0, 1) — stride-0 dim contributes nothing regardless of coord. + auto layout = make({4, 8}, {0, 1}); + EXPECT_EQ(layout.evaluate({0, 0}), 0); + EXPECT_EQ(layout.evaluate({3, 5}), 5); // row coord ignored + EXPECT_EQ(layout.evaluate({0, 7}), 7); +} + +//===----------------------------------------------------------------------===// +// Coalesce +//===----------------------------------------------------------------------===// + +TEST_F(PackMapTest, CoalesceContiguous) { + // Nested: ((2, 4), 8) : ((32, 8), 1) — leaves (2,4,8):(32,8,1), all + // contiguous. Merge right-to-left: 8*1=8==8 -> (32,1), 32*1=32==32 -> (64,1). + auto layout = make({{2, 4}, 8}, {{32, 8}, 1}); + EXPECT_EQ(layout.coalesce(), make({64}, {1})); +} + +TEST_F(PackMapTest, CoalesceNonContiguous) { + // (4, 8) : (1, 4) — column-major, not contiguous in lex order + auto layout = make({4, 8}, {1, 4}); + auto coalesced = layout.coalesce(); + EXPECT_EQ(coalesced.getRank(), 2); + + // (4, 8) : (16, 1) — holes between groups: stride 16 > 8*1 = 8, so + // accShape*accStride (8) != 16. Stays at rank 2. cosize=56 > size=32. + auto holey = make({4, 8}, {16, 1}); + EXPECT_EQ(holey.coalesce(), make({4, 8}, {16, 1})); +} + +TEST_F(PackMapTest, CoalesceRemovesUnitModes) { + // (1, 8) : (0, 1) -> coalesced to (8) : (1) + auto layout = make({1, 8}, {0, 1}); + EXPECT_EQ(layout.coalesce(), make({8}, {1})); +} + +//===----------------------------------------------------------------------===// +// CoalesceModes +//===----------------------------------------------------------------------===// + +TEST_F(PackMapTest, CoalesceModesMergesWithinMode) { + // (4, (2, 4)) : (8, (4, 1)) -- mode 0 is a leaf, mode 1 has contiguous + // sub-leaves and merges internally. + auto layout = make({4, {2, 4}}, {8, {4, 1}}); + EXPECT_EQ(layout.coalesceModes(), make({4, 8}, {8, 1})); +} + +TEST_F(PackMapTest, CoalesceModesVsCoalesce) { + // coalesce merges across mode boundaries; coalesceModes does not. + // (4, (2, 4)) : (8, (4, 1)): leaves (4:8, 2:4, 4:1) are all contiguous, + // so coalesce merges everything to (32):(1), while coalesceModes only + // merges mode 1 internally, preserving the boundary. + auto layout = make({4, {2, 4}}, {8, {4, 1}}); + EXPECT_EQ(layout.coalesce(), make({32}, {1})); + EXPECT_EQ(layout.coalesceModes(), make({4, 8}, {8, 1})); +} + +TEST_F(PackMapTest, CoalesceModesNonContiguousWithinMode) { + // (4, (2, 4)) : (8, (8, 1)) -- mode 1 sub-leaves (2,4):(8,1) are not + // contiguous (8 != 4*1=4), so mode 1 stays unchanged. + auto layout = make({4, {2, 4}}, {8, {8, 1}}); + EXPECT_EQ(layout.coalesceModes(), make({4, {2, 4}}, {8, {8, 1}})); +} + +TEST_F(PackMapTest, CoalesceModesRemovesUnitModes) { + // (1, (2, 4)) : (5, (4, 1)) -- mode 0 is size-1, stride normalized to 0; + // mode 1 merges to (8):(1). + auto layout = make({1, {2, 4}}, {5, {4, 1}}); + EXPECT_EQ(layout.coalesceModes(), make({1, 8}, {0, 1})); +} + +//===----------------------------------------------------------------------===// +// Flatten +//===----------------------------------------------------------------------===// + +TEST_F(PackMapTest, FlattenHierarchical) { + // ((2, 4), 8) : ((16, 1), 4) + auto layout = make({{2, 4}, 8}, {{16, 1}, 4}); + EXPECT_EQ(layout.flatten(), make({2, 4, 8}, {16, 1, 4})); +} + +//===----------------------------------------------------------------------===// +// Compose +//===----------------------------------------------------------------------===// + +TEST_F(PackMapTest, ComposeIdentity) { + // Nested LHS: ((2, 4), 4) : ((16, 4), 1) composed with identity. + // Result must equal A for all flat indices. + auto a = make({{2, 4}, 4}, {{16, 4}, 1}); + auto id = PackMapAttr::makeIdentity(getContext(), {32}); + auto result = a.compose(id); + for (int i = 0; i < 32; ++i) { + EXPECT_EQ(result.evaluate({i}), a.evaluate({i})); + } +} + +TEST_F(PackMapTest, ComposeWithBroadcast) { + // stride-0 in RHS should produce stride-0 in result. + auto a = make({8}, {1}); + auto rhs = make({4}, {0}); + auto result = a.compose(rhs); + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(result.evaluate({i}), 0); + } +} + +TEST_F(PackMapTest, ComposeMultiDigit) { + // Compose column-major LHS with a stride-2 RHS. + // LHS (4, 8):(1, 4) maps flat index i -> (i%4)*1 + (i/4)*4. + // RHS (8):(2) maps index j -> 2*j. + // Composed: result(j) = LHS(2*j). + auto lhs = make({4, 8}, {1, 4}); + auto rhs = make({8}, {2}); + auto result = lhs.compose(rhs); + // Verify functionally: result(j) == lhs(2*j) for all j in [0, 8). + for (int j = 0; j < 8; ++j) { + EXPECT_EQ(result.evaluate({j}), lhs.evaluate({2 * j})); + } +} + +TEST_F(PackMapTest, ComposeMultiDigitRowMajor) { + // Compose row-major 4x8 with a 2x4 row-major layout. + // LHS = (4, 8):(8, 1), RHS = (2, 4):(4, 1). + // result(i, j) = LHS(RHS(i, j)) = LHS(4*i + j). + auto lhs = make({4, 8}, {8, 1}); + auto rhs = make({2, 4}, {4, 1}); + auto result = lhs.compose(rhs); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 4; ++j) { + int rhsIdx = 4 * i + j; + EXPECT_EQ(result.evaluate({i, j}), lhs.evaluate({rhsIdx})); + } + } +} + +//===----------------------------------------------------------------------===// +// Complement +//===----------------------------------------------------------------------===// + +TEST_F(PackMapTest, ComplementBasic) { + // Nested ((2, 2)) : ((1, 2)) — leaves (2,2):(1,2), cotarget=16. + // Strides contiguous (acc 1->2->4), trailing 16/4=4 -> complement (4):(4). + auto layout = make({{2, 2}}, {{1, 2}}); + EXPECT_EQ(layout.complement(16), make({4}, {4})); +} + +TEST_F(PackMapTest, ComplementWithGap) { + // (4) : (2) with cotarget=8. + // Original covers {0, 2, 4, 6}. Complement fills the interleaving gap. + auto layout = make({4}, {2}); + EXPECT_EQ(layout.complement(8), make({2}, {1})); +} + +//===----------------------------------------------------------------------===// +// LogicalDivide +//===----------------------------------------------------------------------===// + +TEST_F(PackMapTest, LogicalDivideBasic) { + // Nested layout ((4, 8)) : ((8, 1)) — same function as (32):(1). + // Divide by tiler (4):(1) -> inner=(4):(1), outer=(8):(4). + auto layout = make({{4, 8}}, {{8, 1}}); + auto tiler = make({4}, {1}); + // Each mode stores its sub-layout's shape as a 1-element tuple. + auto result = layout.logicalDivide(tiler); + EXPECT_EQ(result, make({{4}, {8}}, {{1}, {4}})); + // Functional: result(inner, outer) = inner + tileSize * outer. + for (int outer = 0; outer < 8; ++outer) { + for (int inner = 0; inner < 4; ++inner) { + EXPECT_EQ(result.evaluate({inner, outer}), inner + 4 * outer); + } + } +} + +//===----------------------------------------------------------------------===// +// LogicalProduct +//===----------------------------------------------------------------------===// + +TEST_F(PackMapTest, LogicalProductBasic) { + // Nested layout ((2, 2)) : ((1, 2)) — same size=4 as (4):(1). + // Product with (8):(1) -> mode 0 size=4, mode 1 size=8. + auto layout = make({{2, 2}}, {{1, 2}}); + auto tiler = make({8}, {1}); + // Mode 0 stores original layout's full shape; mode 1 stores + // comp.compose(tiler). + EXPECT_EQ(layout.logicalProduct(tiler), + make({{{2, 2}}, {8}}, {{{1, 2}}, {4}})); +} + +//===----------------------------------------------------------------------===// +// Filter +//===----------------------------------------------------------------------===// + +TEST_F(PackMapTest, FilterRemovesBroadcast) { + // Nested ((4, 8)) : ((0, 1)) — stride-0 leaf removed, leaves (8):(1). + auto layout = make({{4, 8}}, {{0, 1}}); + EXPECT_EQ(layout.filter(), make({8}, {1})); +} + +TEST_F(PackMapTest, FilterRemovesUnitSize) { + // (1, 8) : (5, 1) -> filter removes size-1 mode + auto layout = make({1, 8}, {5, 1}); + EXPECT_EQ(layout.filter(), make({8}, {1})); +} + +//===----------------------------------------------------------------------===// +// RightInverse +//===----------------------------------------------------------------------===// + +TEST_F(PackMapTest, RightInverse) { + // Nested ((2, 4)) : ((4, 1)) — same function as (8):(1) (row-major). + auto layout = make({{2, 4}}, {{4, 1}}); + auto ri = layout.rightInverse(); + // A(R(x)) = x for all x in [0, 8) + for (int i = 0; i < 8; ++i) { + int64_t intermediate = ri.evaluate({i}); + EXPECT_EQ(layout.evaluate({intermediate}), i); + } +} + +TEST_F(PackMapTest, RightInverseColumnMajor) { + // (4, 2) : (1, 4) — column-major + auto layout = make({4, 2}, {1, 4}); + auto ri = layout.rightInverse(); + for (int i = 0; i < 8; ++i) { + int64_t intermediate = ri.evaluate({i}); + EXPECT_EQ(layout.evaluate({intermediate}), i); + } +} + +TEST_F(PackMapTest, RightInverseNonSurjective) { + // (4):(2) maps to {0, 2, 4, 6} — injective but stride starts at 2, not 1. + // rightInverse collects only contiguous strides starting from 1, so the + // result is trivial (1):(0) since no stride-1 leaf exists. + auto layout = make({4}, {2}); + EXPECT_EQ(layout.rightInverse(), make({1}, {0})); +} + +TEST_F(PackMapTest, LeftInverseNonSurjective) { + // (4):(2) maps to {0, 2, 4, 6}. leftInverse uses complement to fill gaps, + // then rightInverse of the combined layout. L(A(x)) = x for all x in [0, 4). + auto layout = make({4}, {2}); + auto li = layout.leftInverse(); + for (int i = 0; i < 4; ++i) { + int64_t output = layout.evaluate({i}); + EXPECT_EQ(li.evaluate({output}), i); + } +} + +//===----------------------------------------------------------------------===// +// LeftInverse +//===----------------------------------------------------------------------===// + +TEST_F(PackMapTest, LeftInverse) { + // Nested ((2, 4)) : ((4, 1)) — same function as (8):(1). + auto layout = make({{2, 4}}, {{4, 1}}); + auto li = layout.leftInverse(); + // L(A(x)) = x for all x in [0, 8) + for (int i = 0; i < 8; ++i) { + int64_t intermediate = layout.evaluate({i}); + EXPECT_EQ(li.evaluate({intermediate}), i); + } +} + +//===----------------------------------------------------------------------===// +// TiledDivide / TiledProduct +//===----------------------------------------------------------------------===// + +TEST_F(PackMapTest, TiledDivide) { + // Nested ((4, 8)) : ((8, 1)) — same function as (32):(1). + auto layout = make({{4, 8}}, {{8, 1}}); + auto tiler = make({4}, {1}); + // flattenRestModes unwraps mode 1's 1-element tuple into a leaf. + EXPECT_EQ(layout.tiledDivide(tiler), make({{4}, 8}, {{1}, 4})); +} + +TEST_F(PackMapTest, TiledProduct) { + auto layout = make({4}, {1}); + auto tiler = make({8}, {1}); + EXPECT_EQ(layout.tiledProduct(tiler), make({{4}, 8}, {{1}, 4})); +} + +//===----------------------------------------------------------------------===// +// Permute +//===----------------------------------------------------------------------===// + +TEST_F(PackMapTest, Permute) { + // mode 0 = (2, 2):(4, 1), mode 1 = 8:8 -> permute({1,0}) swaps them. + auto layout = make({{2, 2}, 8}, {{4, 1}, 8}); + EXPECT_EQ(layout.permute({1, 0}), make({8, {2, 2}}, {8, {4, 1}})); +} + +//===----------------------------------------------------------------------===// +// Project +//===----------------------------------------------------------------------===// + +TEST_F(PackMapTest, Project) { + // mode 0 = (2, 2):(4, 1), mode 1 = 8:2, mode 2 = 2:1 -> drop mode 1. + auto layout = make({{2, 2}, 8, 2}, {{4, 1}, 2, 1}); + EXPECT_EQ(layout.project({false, true, false}), + make({{2, 2}, 2}, {{4, 1}, 1})); +} + +//===----------------------------------------------------------------------===// +// MakeIdentity +//===----------------------------------------------------------------------===// + +TEST_F(PackMapTest, MakeIdentity) { + auto id = PackMapAttr::makeIdentity(getContext(), {4, 8}); + EXPECT_EQ(id, make({4, 8}, {8, 1})); + for (int i = 0; i < 32; ++i) { + EXPECT_EQ(id.evaluate({i}), i); + } +} + +TEST_F(PackMapTest, MakeIdentity3D) { + EXPECT_EQ(PackMapAttr::makeIdentity(getContext(), {2, 3, 4}), + make({2, 3, 4}, {12, 4, 1})); +} + +//===----------------------------------------------------------------------===// +// IntTuple: filterLeafInfos and foldLeafInfos +//===----------------------------------------------------------------------===// + +TEST_F(PackMapTest, FilterLeafInfosZeroStride) { + // (4, 8) : (0, 1) -> zero-stride leaves: LeafInfo{4, 0, 8} + auto layout = make({4, 8}, {0, 1}); + auto zeroStride = + filterLeafInfos(layout.getShape(), layout.getStride(), + [](const LeafInfo &l) { return l.stride == 0; }); + EXPECT_EQ(zeroStride.size(), 1u); + EXPECT_EQ(zeroStride[0].size, 4); +} + +TEST_F(PackMapTest, FilterLeafInfosNonZeroStride) { + auto layout = make({4, 8}, {2, 1}); + auto nonZero = + filterLeafInfos(layout.getShape(), layout.getStride(), + [](const LeafInfo &l) { return l.stride > 0; }); + EXPECT_EQ(nonZero.size(), 2u); +} + +TEST_F(PackMapTest, FoldLeafInfosProductOfSizes) { + // Product of sizes for stride > 0 leaves. + auto layout = make({4, 8}, {2, 1}); + int64_t product = foldLeafInfos(layout.getShape(), layout.getStride(), 1, + [](int64_t acc, const LeafInfo &l) { + return l.stride > 0 ? acc * l.size : acc; + }); + EXPECT_EQ(product, 32); // 4 * 8 +} + +TEST_F(PackMapTest, FoldLeafInfosWithBroadcast) { + // (4, 8) : (0, 1) -> product of stride>0 sizes = 8 + auto layout = make({4, 8}, {0, 1}); + int64_t product = foldLeafInfos(layout.getShape(), layout.getStride(), 1, + [](int64_t acc, const LeafInfo &l) { + return l.stride > 0 ? acc * l.size : acc; + }); + EXPECT_EQ(product, 8); +} + +} // namespace +} // namespace mlir::iree_compiler::IREE::Map diff --git a/compiler/src/iree/compiler/Tools/BUILD.bazel b/compiler/src/iree/compiler/Tools/BUILD.bazel index 7cd7b00b6937..e67d6713bb1a 100644 --- a/compiler/src/iree/compiler/Tools/BUILD.bazel +++ b/compiler/src/iree/compiler/Tools/BUILD.bazel @@ -37,6 +37,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Codegen/Dialect/CPU/IR:IREECPUDialect", "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect", "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect", + "//compiler/src/iree/compiler/Codegen/Dialect/Map/IR:IREEMapDialect", "//compiler/src/iree/compiler/Codegen/Dialect/PCF/IR", "//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect", "//compiler/src/iree/compiler/Codegen/Interfaces", diff --git a/compiler/src/iree/compiler/Tools/CMakeLists.txt b/compiler/src/iree/compiler/Tools/CMakeLists.txt index a6d2d3d709f0..9b40918f9551 100644 --- a/compiler/src/iree/compiler/Tools/CMakeLists.txt +++ b/compiler/src/iree/compiler/Tools/CMakeLists.txt @@ -31,6 +31,7 @@ iree_cc_library( iree::compiler::Bindings::TFLite::Transforms iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect + iree::compiler::Codegen::Dialect::Map::IR::IREEMapDialect iree::compiler::Codegen::Dialect::PCF::IR iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect iree::compiler::Codegen::Interfaces::Interfaces diff --git a/compiler/src/iree/compiler/Tools/init_iree_dialects.h b/compiler/src/iree/compiler/Tools/init_iree_dialects.h index c47ae6cb4368..4abbc699bbe9 100644 --- a/compiler/src/iree/compiler/Tools/init_iree_dialects.h +++ b/compiler/src/iree/compiler/Tools/init_iree_dialects.h @@ -16,6 +16,7 @@ #include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUDialect.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h" +#include "iree/compiler/Codegen/Dialect/Map/IR/IREEMapDialect.h" #include "iree/compiler/Codegen/Dialect/PCF/IR/PCFDialect.h" #include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h" #include "iree/compiler/Codegen/Interfaces/Interfaces.h" @@ -45,6 +46,7 @@ inline void registerIreeDialects(DialectRegistry ®istry) { IREE::Codegen::IREECodegenDialect, IREE::Flow::FlowDialect, IREE::GPU::IREEGPUDialect, + IREE::Map::IREEMapDialect, IREE::HAL::HALDialect, IREE::HAL::Inline::HALInlineDialect, IREE::HAL::Loader::HALLoaderDialect,