|
| 1 | +// Copyright 2026 The IREE Authors |
| 2 | +// |
| 3 | +// Licensed under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | + |
| 7 | +#include "iree/compiler/Codegen/Dialect/Map/ExternalInterfaces/VectorLayoutInterfaceImpl.h" |
| 8 | + |
| 9 | +#include "iree/compiler/Codegen/Dialect/Map/IR/IREEMapAttrs.h" |
| 10 | +#include "iree/compiler/Codegen/Dialect/Map/IR/IREEMapDialect.h" |
| 11 | +#include "iree/compiler/Codegen/Dialect/Map/IR/IntTuple.h" |
| 12 | +#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtInterfaces.h" |
| 13 | +#include "mlir/IR/AffineMap.h" |
| 14 | +#include "mlir/IR/BuiltinTypes.h" |
| 15 | + |
| 16 | +namespace mlir::iree_compiler::IREE::Map { |
| 17 | + |
| 18 | +using VectorExt::VectorLayoutInterface; |
| 19 | + |
| 20 | +namespace { |
| 21 | + |
| 22 | +struct PackLayoutModel final |
| 23 | + : public VectorLayoutInterface::ExternalModel<PackLayoutModel, |
| 24 | + PackLayoutAttr> { |
| 25 | + int64_t getRank(Attribute attr) const { |
| 26 | + return cast<PackLayoutAttr>(attr).getRank(); |
| 27 | + } |
| 28 | + |
| 29 | + SmallVector<int64_t> getUndistributedShape(Attribute attr) const { |
| 30 | + auto layout = cast<PackLayoutAttr>(attr); |
| 31 | + SmallVector<int64_t> shape; |
| 32 | + int32_t rank = layout.getRank(); |
| 33 | + shape.reserve(rank); |
| 34 | + for (int32_t i = 0; i < rank; ++i) { |
| 35 | + shape.push_back(getSize(layout.getShapeMode(i))); |
| 36 | + } |
| 37 | + return shape; |
| 38 | + } |
| 39 | + |
| 40 | + SmallVector<int64_t> getDistributedShape(Attribute attr) const { |
| 41 | + auto layout = cast<PackLayoutAttr>(attr); |
| 42 | + SmallVector<int64_t> shape; |
| 43 | + int32_t rank = layout.getRank(); |
| 44 | + shape.reserve(rank); |
| 45 | + for (int32_t i = 0; i < rank; ++i) { |
| 46 | + // Stride-0 leaves represent per-thread value dimensions (broadcast). |
| 47 | + SmallVector<LeafInfo> valLeaves = |
| 48 | + filterLeafInfos(layout.getShapeMode(i), layout.getStrideMode(i), |
| 49 | + [](const LeafInfo &l) { return l.stride == 0; }); |
| 50 | + if (valLeaves.empty()) { |
| 51 | + shape.push_back(1); |
| 52 | + } else { |
| 53 | + for (auto &leaf : valLeaves) { |
| 54 | + shape.push_back(static_cast<int64_t>(leaf.size)); |
| 55 | + } |
| 56 | + } |
| 57 | + } |
| 58 | + return shape; |
| 59 | + } |
| 60 | + |
| 61 | + LogicalResult isValidLayout(Attribute attr, ShapedType shapeTy, |
| 62 | + Location loc) const { |
| 63 | + auto layout = cast<PackLayoutAttr>(attr); |
| 64 | + int64_t rank = layout.getRank(); |
| 65 | + ArrayRef<int64_t> vecShape = shapeTy.getShape(); |
| 66 | + if (static_cast<int64_t>(vecShape.size()) != rank) { |
| 67 | + return emitError(loc, "Rank of vector (") |
| 68 | + << vecShape.size() << ") does not match rank of layout (" << rank |
| 69 | + << ")."; |
| 70 | + } |
| 71 | + SmallVector<int64_t> expected = getUndistributedShape(attr); |
| 72 | + for (int64_t i = 0; i < rank; ++i) { |
| 73 | + if (ShapedType::isStatic(vecShape[i]) && expected[i] != vecShape[i]) { |
| 74 | + return emitError(loc, "Vector shape mismatch at dim ") |
| 75 | + << i << ": expected " << expected[i] << ", got " << vecShape[i]; |
| 76 | + } |
| 77 | + } |
| 78 | + return success(); |
| 79 | + } |
| 80 | + |
| 81 | + VectorLayoutInterface permute(Attribute attr, ArrayRef<int64_t> perm) const { |
| 82 | + return VectorLayoutInterface(cast<PackLayoutAttr>(attr).permute(perm)); |
| 83 | + } |
| 84 | + |
| 85 | + VectorLayoutInterface project(Attribute attr, |
| 86 | + ArrayRef<bool> droppedDims) const { |
| 87 | + return VectorLayoutInterface( |
| 88 | + cast<PackLayoutAttr>(attr).project(droppedDims)); |
| 89 | + } |
| 90 | + |
| 91 | + VectorLayoutInterface apply(Attribute attr, AffineMap map) const { |
| 92 | + auto layout = cast<PackLayoutAttr>(attr); |
| 93 | + MLIRContext *ctx = attr.getContext(); |
| 94 | + int64_t numResults = map.getNumResults(); |
| 95 | + |
| 96 | + SmallVector<Attribute> modeShapes(numResults); |
| 97 | + SmallVector<Attribute> modeStrides(numResults); |
| 98 | + |
| 99 | + for (auto [idx, expr] : llvm::enumerate(map.getResults())) { |
| 100 | + if (auto dim = dyn_cast<AffineDimExpr>(expr)) { |
| 101 | + int64_t pos = dim.getPosition(); |
| 102 | + modeShapes[idx] = layout.getShapeMode(pos); |
| 103 | + modeStrides[idx] = layout.getStrideMode(pos); |
| 104 | + } else { |
| 105 | + // Non-dim expressions (constants, adds, etc.) lose layout info. |
| 106 | + modeShapes[idx] = makeLeaf(ctx, 1); |
| 107 | + modeStrides[idx] = makeLeaf(ctx, 0); |
| 108 | + } |
| 109 | + } |
| 110 | + |
| 111 | + return VectorLayoutInterface(PackLayoutAttr::get( |
| 112 | + ctx, makeTuple(ctx, modeShapes), makeTuple(ctx, modeStrides))); |
| 113 | + } |
| 114 | + |
| 115 | + VectorLayoutInterface reshape(Attribute attr, |
| 116 | + ArrayRef<int64_t> newShape) const { |
| 117 | + auto layout = cast<PackLayoutAttr>(attr); |
| 118 | + auto newShapeId = |
| 119 | + PackMapAttr::makeIdentity(layout.getMap().getContext(), newShape); |
| 120 | + return VectorLayoutInterface(PackLayoutAttr::get( |
| 121 | + attr.getContext(), layout.getMap().compose(newShapeId))); |
| 122 | + } |
| 123 | + |
| 124 | + bool |
| 125 | + needsSharedMemoryForConversion(Attribute attr, |
| 126 | + VectorLayoutInterface targetLayout) const { |
| 127 | + auto targetLayoutAttr = dyn_cast_if_present<PackLayoutAttr>(targetLayout); |
| 128 | + if (!targetLayoutAttr) { |
| 129 | + return true; |
| 130 | + } |
| 131 | + auto srcLayoutAttr = cast<PackLayoutAttr>(attr); |
| 132 | + return srcLayoutAttr.coalesce() != targetLayoutAttr.coalesce(); |
| 133 | + } |
| 134 | + |
| 135 | + static VectorLayoutInterface |
| 136 | + getRecombinedLayout(ArrayRef<VectorLayoutInterface> layouts, |
| 137 | + ArrayRef<AffineMap> maps, AffineMap resultMap) { |
| 138 | + if (!llvm::all_of(layouts, llvm::IsaPred<PackLayoutAttr>)) { |
| 139 | + return VectorLayoutInterface(); |
| 140 | + } |
| 141 | + MLIRContext *ctx = resultMap.getContext(); |
| 142 | + |
| 143 | + SmallVector<PackLayoutAttr> packLayouts; |
| 144 | + llvm::transform(layouts, std::back_inserter(packLayouts), |
| 145 | + [](VectorLayoutInterface layout) { |
| 146 | + return cast<PackLayoutAttr>(layout); |
| 147 | + }); |
| 148 | + |
| 149 | + int64_t resRank = resultMap.getNumResults(); |
| 150 | + |
| 151 | + // Null Attribute serves as "not yet assigned" sentinel. |
| 152 | + Attribute unset; |
| 153 | + SmallVector<Attribute> modeShapes(resRank, unset); |
| 154 | + SmallVector<Attribute> modeStrides(resRank, unset); |
| 155 | + |
| 156 | + for (auto [layout, indexingMap] : llvm::zip(packLayouts, maps)) { |
| 157 | + for (int64_t resultIdx = 0; |
| 158 | + resultIdx < static_cast<int64_t>(indexingMap.getNumResults()); |
| 159 | + ++resultIdx) { |
| 160 | + auto dimExpr = |
| 161 | + dyn_cast<AffineDimExpr>(indexingMap.getResult(resultIdx)); |
| 162 | + if (!dimExpr) { |
| 163 | + continue; |
| 164 | + } |
| 165 | + int64_t iterPos = dimExpr.getPosition(); |
| 166 | + auto maybeResPos = |
| 167 | + resultMap.getResultPosition(getAffineDimExpr(iterPos, ctx)); |
| 168 | + if (!maybeResPos.has_value()) { |
| 169 | + continue; |
| 170 | + } |
| 171 | + int64_t resPos = maybeResPos.value(); |
| 172 | + |
| 173 | + Attribute ms = layout.getShapeMode(resultIdx); |
| 174 | + Attribute md = layout.getStrideMode(resultIdx); |
| 175 | + |
| 176 | + if (modeShapes[resPos] && modeShapes[resPos] != ms) { |
| 177 | + return VectorLayoutInterface(); |
| 178 | + } |
| 179 | + modeShapes[resPos] = ms; |
| 180 | + modeStrides[resPos] = md; |
| 181 | + } |
| 182 | + } |
| 183 | + |
| 184 | + for (int64_t i = 0; i < resRank; ++i) { |
| 185 | + if (!modeShapes[i]) { |
| 186 | + modeShapes[i] = makeLeaf(ctx, 1); |
| 187 | + modeStrides[i] = makeLeaf(ctx, 0); |
| 188 | + } |
| 189 | + } |
| 190 | + |
| 191 | + return VectorLayoutInterface(PackLayoutAttr::get( |
| 192 | + ctx, makeTuple(ctx, modeShapes), makeTuple(ctx, modeStrides))); |
| 193 | + } |
| 194 | +}; |
| 195 | + |
| 196 | +} // namespace |
| 197 | + |
| 198 | +void registerVectorLayoutInterfaceExternalModels(DialectRegistry ®istry) { |
| 199 | + registry.addExtension(+[](MLIRContext *context, IREEMapDialect *dialect) { |
| 200 | + PackLayoutAttr::attachInterface<PackLayoutModel>(*context); |
| 201 | + }); |
| 202 | +} |
| 203 | + |
| 204 | +} // namespace mlir::iree_compiler::IREE::Map |
0 commit comments