Skip to content

Commit e9201c2

Browse files
Groverkssclaude
andauthored
[Codegen] Add PackLayoutAttr and VectorLayoutInterface to iree_map dialect (#23672)
Adds `PackLayoutAttr` to the `iree_map` dialect — a wrapper around `PackMapAttr` that stores the map in by-mode-coalesced form and implements `VectorLayoutInterface` for use in vector distribution. Includes the ExternalInterfaces library wiring. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 7649e20 commit e9201c2

17 files changed

Lines changed: 797 additions & 0 deletions
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright 2026 The IREE Authors
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library")
8+
9+
package(
10+
default_visibility = ["//visibility:public"],
11+
features = ["layering_check"],
12+
licenses = ["notice"], # Apache 2.0
13+
)
14+
15+
iree_compiler_cc_library(
16+
name = "ExternalModels",
17+
srcs = [
18+
"Interfaces.cpp",
19+
"VectorLayoutInterfaceImpl.cpp",
20+
],
21+
hdrs = [
22+
"Interfaces.h",
23+
"VectorLayoutInterfaceImpl.h",
24+
],
25+
deps = [
26+
"//compiler/src/iree/compiler/Codegen/Dialect/Map/IR:IREEMapDialect",
27+
"//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect",
28+
"@llvm-project//llvm:Support",
29+
"@llvm-project//mlir:IR",
30+
"@llvm-project//mlir:Support",
31+
],
32+
)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
################################################################################
2+
# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
3+
# compiler/src/iree/compiler/Codegen/Dialect/Map/ExternalInterfaces/BUILD.bazel#
4+
# #
5+
# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
6+
# CMake-only content. #
7+
# #
8+
# To disable autogeneration for this file entirely, delete this header. #
9+
################################################################################
10+
11+
iree_add_all_subdirs()
12+
13+
iree_cc_library(
14+
NAME
15+
ExternalModels
16+
HDRS
17+
"Interfaces.h"
18+
"VectorLayoutInterfaceImpl.h"
19+
SRCS
20+
"Interfaces.cpp"
21+
"VectorLayoutInterfaceImpl.cpp"
22+
DEPS
23+
LLVMSupport
24+
MLIRIR
25+
MLIRSupport
26+
iree::compiler::Codegen::Dialect::Map::IR::IREEMapDialect
27+
iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect
28+
PUBLIC
29+
)
30+
31+
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// Copyright 2026 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "iree/compiler/Codegen/Dialect/Map/ExternalInterfaces/Interfaces.h"
8+
9+
#include "iree/compiler/Codegen/Dialect/Map/ExternalInterfaces/VectorLayoutInterfaceImpl.h"
10+
11+
namespace mlir::iree_compiler {
12+
13+
void registerIREEMapExternalInterfaces(DialectRegistry &registry) {
14+
IREE::Map::registerVectorLayoutInterfaceExternalModels(registry);
15+
}
16+
17+
} // namespace mlir::iree_compiler
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// Copyright 2026 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#ifndef IREE_COMPILER_CODEGEN_DIALECT_MAP_EXTERNALINTERFACES_INTERFACES_H_
8+
#define IREE_COMPILER_CODEGEN_DIALECT_MAP_EXTERNALINTERFACES_INTERFACES_H_
9+
10+
#include "mlir/IR/DialectRegistry.h"
11+
12+
namespace mlir::iree_compiler {
13+
14+
/// Registers all external interface implementations for the IREE Map dialect.
15+
void registerIREEMapExternalInterfaces(DialectRegistry &registry);
16+
17+
} // namespace mlir::iree_compiler
18+
19+
#endif // IREE_COMPILER_CODEGEN_DIALECT_MAP_EXTERNALINTERFACES_INTERFACES_H_
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
// Copyright 2026 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "iree/compiler/Codegen/Dialect/Map/ExternalInterfaces/VectorLayoutInterfaceImpl.h"
8+
9+
#include "iree/compiler/Codegen/Dialect/Map/IR/IREEMapAttrs.h"
10+
#include "iree/compiler/Codegen/Dialect/Map/IR/IREEMapDialect.h"
11+
#include "iree/compiler/Codegen/Dialect/Map/IR/IntTuple.h"
12+
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtInterfaces.h"
13+
#include "mlir/IR/AffineMap.h"
14+
#include "mlir/IR/BuiltinTypes.h"
15+
16+
namespace mlir::iree_compiler::IREE::Map {
17+
18+
using VectorExt::VectorLayoutInterface;
19+
20+
namespace {
21+
22+
struct PackLayoutModel final
23+
: VectorLayoutInterface::ExternalModel<PackLayoutModel, PackLayoutAttr> {
24+
int64_t getRank(Attribute attr) const {
25+
return cast<PackLayoutAttr>(attr).getRank();
26+
}
27+
28+
SmallVector<int64_t> getUndistributedShape(Attribute attr) const {
29+
return llvm::map_to_vector(cast<PackLayoutAttr>(attr).getShapeModes(),
30+
[](Attribute mode) { return getSize(mode); });
31+
}
32+
33+
SmallVector<int64_t> getDistributedShape(Attribute attr) const {
34+
auto layout = cast<PackLayoutAttr>(attr);
35+
SmallVector<int64_t> shape;
36+
for (auto [shapeMode, strideMode] :
37+
llvm::zip_equal(layout.getShapeModes(), layout.getStrideModes())) {
38+
// Stride-0 leaves represent per-thread value dimensions (broadcast).
39+
SmallVector<LeafInfo> valLeaves =
40+
filterLeafInfos(shapeMode, strideMode,
41+
[](const LeafInfo &l) { return l.stride == 0; });
42+
if (valLeaves.empty()) {
43+
shape.push_back(1);
44+
continue;
45+
}
46+
for (auto &leaf : valLeaves) {
47+
shape.push_back(static_cast<int64_t>(leaf.size));
48+
}
49+
}
50+
return shape;
51+
}
52+
53+
LogicalResult isValidLayout(Attribute attr, ShapedType shapeTy,
54+
Location loc) const {
55+
auto layout = cast<PackLayoutAttr>(attr);
56+
int64_t rank = layout.getRank();
57+
ArrayRef<int64_t> vecShape = shapeTy.getShape();
58+
if (static_cast<int64_t>(vecShape.size()) != rank) {
59+
return emitError(loc, "Rank of vector (")
60+
<< vecShape.size() << ") does not match rank of layout (" << rank
61+
<< ").";
62+
}
63+
if (isa<RankedTensorType>(shapeTy)) {
64+
// Allow layout size to exceed tensor size for padding/masking.
65+
return success();
66+
}
67+
SmallVector<int64_t> expected = getUndistributedShape(attr);
68+
for (auto [i, vecDim, expDim] : llvm::enumerate(vecShape, expected)) {
69+
if (ShapedType::isStatic(vecDim) && expDim != vecDim) {
70+
return emitError(loc, "Vector shape mismatch at dim ")
71+
<< i << ": expected " << expDim << ", got " << vecDim;
72+
}
73+
}
74+
return success();
75+
}
76+
77+
VectorLayoutInterface permute(Attribute attr, ArrayRef<int64_t> perm) const {
78+
return VectorLayoutInterface(cast<PackLayoutAttr>(attr).permute(perm));
79+
}
80+
81+
VectorLayoutInterface project(Attribute attr,
82+
ArrayRef<bool> droppedDims) const {
83+
return VectorLayoutInterface(
84+
cast<PackLayoutAttr>(attr).project(droppedDims));
85+
}
86+
87+
VectorLayoutInterface apply(Attribute attr, AffineMap map) const {
88+
auto layout = cast<PackLayoutAttr>(attr);
89+
MLIRContext *ctx = attr.getContext();
90+
int64_t numResults = map.getNumResults();
91+
92+
SmallVector<Attribute> modeShapes(numResults);
93+
SmallVector<Attribute> modeStrides(numResults);
94+
95+
for (auto [idx, expr] : llvm::enumerate(map.getResults())) {
96+
if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
97+
int64_t pos = dim.getPosition();
98+
modeShapes[idx] = layout.getShapeMode(pos);
99+
modeStrides[idx] = layout.getStrideMode(pos);
100+
continue;
101+
}
102+
// Non-dim expressions (constants, adds, etc.) lose layout info.
103+
modeShapes[idx] = makeLeaf(ctx, 1);
104+
modeStrides[idx] = makeLeaf(ctx, 0);
105+
}
106+
107+
return VectorLayoutInterface(PackLayoutAttr::get(
108+
ctx, makeTuple(ctx, modeShapes), makeTuple(ctx, modeStrides)));
109+
}
110+
111+
VectorLayoutInterface reshape(Attribute attr,
112+
ArrayRef<int64_t> newShape) const {
113+
auto layout = cast<PackLayoutAttr>(attr);
114+
auto newShapeId =
115+
PackMapAttr::makeIdentity(layout.getMap().getContext(), newShape);
116+
return VectorLayoutInterface(PackLayoutAttr::get(
117+
attr.getContext(), layout.getMap().compose(newShapeId)));
118+
}
119+
120+
bool
121+
needsSharedMemoryForConversion(Attribute attr,
122+
VectorLayoutInterface targetLayout) const {
123+
auto targetLayoutAttr = dyn_cast_if_present<PackLayoutAttr>(targetLayout);
124+
if (!targetLayoutAttr) {
125+
return true;
126+
}
127+
auto srcLayoutAttr = cast<PackLayoutAttr>(attr);
128+
return srcLayoutAttr.coalesce() != targetLayoutAttr.coalesce();
129+
}
130+
131+
static VectorLayoutInterface
132+
getRecombinedLayout(ArrayRef<VectorLayoutInterface> layouts,
133+
ArrayRef<AffineMap> maps, AffineMap resultMap) {
134+
if (!llvm::all_of(layouts, llvm::IsaPred<PackLayoutAttr>)) {
135+
return VectorLayoutInterface();
136+
}
137+
MLIRContext *ctx = resultMap.getContext();
138+
139+
SmallVector<PackLayoutAttr> packLayouts =
140+
llvm::map_to_vector(layouts, llvm::CastTo<PackLayoutAttr>);
141+
142+
int64_t resRank = resultMap.getNumResults();
143+
144+
// Null Attribute serves as "not yet assigned" sentinel.
145+
Attribute unset;
146+
SmallVector<Attribute> modeShapes(resRank, unset);
147+
SmallVector<Attribute> modeStrides(resRank, unset);
148+
149+
for (auto [layout, indexingMap] : llvm::zip_equal(packLayouts, maps)) {
150+
for (int64_t resultIdx = 0;
151+
resultIdx < static_cast<int64_t>(indexingMap.getNumResults());
152+
++resultIdx) {
153+
auto dimExpr =
154+
dyn_cast<AffineDimExpr>(indexingMap.getResult(resultIdx));
155+
if (!dimExpr) {
156+
continue;
157+
}
158+
int64_t iterPos = dimExpr.getPosition();
159+
auto maybeResPos =
160+
resultMap.getResultPosition(getAffineDimExpr(iterPos, ctx));
161+
if (!maybeResPos.has_value()) {
162+
continue;
163+
}
164+
int64_t resPos = maybeResPos.value();
165+
166+
Attribute shape = layout.getShapeMode(resultIdx);
167+
Attribute stride = layout.getStrideMode(resultIdx);
168+
169+
if (modeShapes[resPos] && modeShapes[resPos] != shape) {
170+
return VectorLayoutInterface();
171+
}
172+
modeShapes[resPos] = shape;
173+
modeStrides[resPos] = stride;
174+
}
175+
}
176+
177+
for (auto [shape, stride] : llvm::zip_equal(modeShapes, modeStrides)) {
178+
if (!shape) {
179+
shape = makeLeaf(ctx, 1);
180+
stride = makeLeaf(ctx, 0);
181+
}
182+
}
183+
184+
return VectorLayoutInterface(PackLayoutAttr::get(
185+
ctx, makeTuple(ctx, modeShapes), makeTuple(ctx, modeStrides)));
186+
}
187+
};
188+
189+
} // namespace
190+
191+
void registerVectorLayoutInterfaceExternalModels(DialectRegistry &registry) {
192+
registry.addExtension(+[](MLIRContext *context, IREEMapDialect *dialect) {
193+
PackLayoutAttr::attachInterface<PackLayoutModel>(*context);
194+
});
195+
}
196+
197+
} // namespace mlir::iree_compiler::IREE::Map
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// Copyright 2026 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#ifndef IREE_COMPILER_CODEGEN_DIALECT_MAP_EXTERNALINTERFACES_VECTORLAYOUTINTERFACEIMPL_H_
8+
#define IREE_COMPILER_CODEGEN_DIALECT_MAP_EXTERNALINTERFACES_VECTORLAYOUTINTERFACEIMPL_H_
9+
10+
#include "mlir/IR/DialectRegistry.h"
11+
12+
namespace mlir::iree_compiler::IREE::Map {
13+
14+
void registerVectorLayoutInterfaceExternalModels(DialectRegistry &registry);
15+
16+
} // namespace mlir::iree_compiler::IREE::Map
17+
18+
#endif // IREE_COMPILER_CODEGEN_DIALECT_MAP_EXTERNALINTERFACES_VECTORLAYOUTINTERFACEIMPL_H_
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright 2026 The IREE Authors
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite")
8+
9+
package(
10+
features = ["layering_check"],
11+
licenses = ["notice"], # Apache 2.0
12+
)
13+
14+
iree_lit_test_suite(
15+
name = "lit",
16+
srcs = [
17+
"pack_layout_vector_analysis.mlir",
18+
],
19+
cfg = "//compiler:lit.cfg.py",
20+
tools = [
21+
"//tools:iree-opt",
22+
"@llvm-project//llvm:FileCheck",
23+
],
24+
)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
################################################################################
2+
# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
3+
# compiler/src/iree/compiler/Codegen/Dialect/Map/ExternalInterfaces/test/BUILD.bazel#
4+
# #
5+
# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
6+
# CMake-only content. #
7+
# #
8+
# To disable autogeneration for this file entirely, delete this header. #
9+
################################################################################
10+
11+
iree_add_all_subdirs()
12+
13+
iree_lit_test_suite(
14+
NAME
15+
lit
16+
SRCS
17+
"pack_layout_vector_analysis.mlir"
18+
TOOLS
19+
FileCheck
20+
iree-opt
21+
)
22+
23+
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###

0 commit comments

Comments
 (0)