Skip to content

Commit b004227

Browse files
Groverkssclaude
andcommitted
[Codegen] Add PackLayoutAttr and VectorLayoutInterface to iree_map dialect
Adds `PackLayoutAttr` to the `iree_map` dialect — a wrapper around `PackMapAttr` that stores the map in by-mode-coalesced form and implements `VectorLayoutInterface` for use in vector distribution. Includes the ExternalInterfaces library wiring. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 8e74bd2 commit b004227

16 files changed

Lines changed: 873 additions & 0 deletions
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright 2026 The IREE Authors
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library")
8+
9+
package(
10+
default_visibility = ["//visibility:public"],
11+
features = ["layering_check"],
12+
licenses = ["notice"], # Apache 2.0
13+
)
14+
15+
iree_compiler_cc_library(
16+
name = "ExternalModels",
17+
srcs = [
18+
"Interfaces.cpp",
19+
"VectorLayoutInterfaceImpl.cpp",
20+
],
21+
hdrs = [
22+
"Interfaces.h",
23+
"VectorLayoutInterfaceImpl.h",
24+
],
25+
deps = [
26+
"//compiler/src/iree/compiler/Codegen/Dialect/Map/IR:IREEMapDialect",
27+
"//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect",
28+
"@llvm-project//llvm:Support",
29+
"@llvm-project//mlir:IR",
30+
"@llvm-project//mlir:Support",
31+
],
32+
)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
################################################################################
2+
# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
3+
# compiler/src/iree/compiler/Codegen/Dialect/Map/ExternalInterfaces/BUILD.bazel#
4+
# #
5+
# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
6+
# CMake-only content. #
7+
# #
8+
# To disable autogeneration for this file entirely, delete this header. #
9+
################################################################################
10+
11+
iree_add_all_subdirs()
12+
13+
iree_cc_library(
14+
NAME
15+
ExternalModels
16+
HDRS
17+
"Interfaces.h"
18+
"VectorLayoutInterfaceImpl.h"
19+
SRCS
20+
"Interfaces.cpp"
21+
"VectorLayoutInterfaceImpl.cpp"
22+
DEPS
23+
LLVMSupport
24+
MLIRIR
25+
MLIRSupport
26+
iree::compiler::Codegen::Dialect::Map::IR::IREEMapDialect
27+
iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect
28+
PUBLIC
29+
)
30+
31+
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// Copyright 2026 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "iree/compiler/Codegen/Dialect/Map/ExternalInterfaces/Interfaces.h"
8+
9+
#include "iree/compiler/Codegen/Dialect/Map/ExternalInterfaces/VectorLayoutInterfaceImpl.h"
10+
11+
namespace mlir::iree_compiler {
12+
13+
void registerIREEMapExternalInterfaces(DialectRegistry &registry) {
14+
IREE::Map::registerVectorLayoutInterfaceExternalModels(registry);
15+
}
16+
17+
} // namespace mlir::iree_compiler
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// Copyright 2026 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#ifndef IREE_COMPILER_CODEGEN_DIALECT_MAP_EXTERNALINTERFACES_INTERFACES_H_
8+
#define IREE_COMPILER_CODEGEN_DIALECT_MAP_EXTERNALINTERFACES_INTERFACES_H_
9+
10+
#include "mlir/IR/DialectRegistry.h"
11+
12+
namespace mlir::iree_compiler {
13+
14+
/// Registers all external interface implementations for the IREE Map dialect.
15+
void registerIREEMapExternalInterfaces(DialectRegistry &registry);
16+
17+
} // namespace mlir::iree_compiler
18+
19+
#endif // IREE_COMPILER_CODEGEN_DIALECT_MAP_EXTERNALINTERFACES_INTERFACES_H_
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
// Copyright 2026 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "iree/compiler/Codegen/Dialect/Map/ExternalInterfaces/VectorLayoutInterfaceImpl.h"
8+
9+
#include "iree/compiler/Codegen/Dialect/Map/IR/IREEMapAttrs.h"
10+
#include "iree/compiler/Codegen/Dialect/Map/IR/IREEMapDialect.h"
11+
#include "iree/compiler/Codegen/Dialect/Map/IR/IntTuple.h"
12+
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtInterfaces.h"
13+
#include "mlir/IR/AffineMap.h"
14+
#include "mlir/IR/BuiltinTypes.h"
15+
16+
namespace mlir::iree_compiler::IREE::Map {
17+
18+
using VectorExt::VectorLayoutInterface;
19+
20+
namespace {
21+
22+
struct PackLayoutModel final
23+
: public VectorLayoutInterface::ExternalModel<PackLayoutModel,
24+
PackLayoutAttr> {
25+
int64_t getRank(Attribute attr) const {
26+
return cast<PackLayoutAttr>(attr).getRank();
27+
}
28+
29+
SmallVector<int64_t> getUndistributedShape(Attribute attr) const {
30+
auto layout = cast<PackLayoutAttr>(attr);
31+
SmallVector<int64_t> shape;
32+
int32_t rank = layout.getRank();
33+
shape.reserve(rank);
34+
for (int32_t i = 0; i < rank; ++i) {
35+
shape.push_back(getSize(layout.getShapeMode(i)));
36+
}
37+
return shape;
38+
}
39+
40+
SmallVector<int64_t> getDistributedShape(Attribute attr) const {
41+
auto layout = cast<PackLayoutAttr>(attr);
42+
SmallVector<int64_t> shape;
43+
int32_t rank = layout.getRank();
44+
shape.reserve(rank);
45+
for (int32_t i = 0; i < rank; ++i) {
46+
// Stride-0 leaves represent per-thread value dimensions (broadcast).
47+
SmallVector<LeafInfo> valLeaves =
48+
filterLeafInfos(layout.getShapeMode(i), layout.getStrideMode(i),
49+
[](const LeafInfo &l) { return l.stride == 0; });
50+
if (valLeaves.empty()) {
51+
shape.push_back(1);
52+
} else {
53+
for (auto &leaf : valLeaves) {
54+
shape.push_back(static_cast<int64_t>(leaf.size));
55+
}
56+
}
57+
}
58+
return shape;
59+
}
60+
61+
LogicalResult isValidLayout(Attribute attr, ShapedType shapeTy,
62+
Location loc) const {
63+
auto layout = cast<PackLayoutAttr>(attr);
64+
int64_t rank = layout.getRank();
65+
ArrayRef<int64_t> vecShape = shapeTy.getShape();
66+
if (static_cast<int64_t>(vecShape.size()) != rank) {
67+
return emitError(loc, "Rank of vector (")
68+
<< vecShape.size() << ") does not match rank of layout (" << rank
69+
<< ").";
70+
}
71+
SmallVector<int64_t> expected = getUndistributedShape(attr);
72+
for (int64_t i = 0; i < rank; ++i) {
73+
if (ShapedType::isStatic(vecShape[i]) && expected[i] != vecShape[i]) {
74+
return emitError(loc, "Vector shape mismatch at dim ")
75+
<< i << ": expected " << expected[i] << ", got " << vecShape[i];
76+
}
77+
}
78+
return success();
79+
}
80+
81+
VectorLayoutInterface permute(Attribute attr, ArrayRef<int64_t> perm) const {
82+
return VectorLayoutInterface(cast<PackLayoutAttr>(attr).permute(perm));
83+
}
84+
85+
VectorLayoutInterface project(Attribute attr,
86+
ArrayRef<bool> droppedDims) const {
87+
return VectorLayoutInterface(
88+
cast<PackLayoutAttr>(attr).project(droppedDims));
89+
}
90+
91+
VectorLayoutInterface apply(Attribute attr, AffineMap map) const {
92+
auto layout = cast<PackLayoutAttr>(attr);
93+
MLIRContext *ctx = attr.getContext();
94+
int64_t numResults = map.getNumResults();
95+
96+
SmallVector<Attribute> modeShapes(numResults);
97+
SmallVector<Attribute> modeStrides(numResults);
98+
99+
for (auto [idx, expr] : llvm::enumerate(map.getResults())) {
100+
if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
101+
int64_t pos = dim.getPosition();
102+
modeShapes[idx] = layout.getShapeMode(pos);
103+
modeStrides[idx] = layout.getStrideMode(pos);
104+
} else {
105+
// Non-dim expressions (constants, adds, etc.) lose layout info.
106+
modeShapes[idx] = makeLeaf(ctx, 1);
107+
modeStrides[idx] = makeLeaf(ctx, 0);
108+
}
109+
}
110+
111+
return VectorLayoutInterface(PackLayoutAttr::get(
112+
ctx, makeTuple(ctx, modeShapes), makeTuple(ctx, modeStrides)));
113+
}
114+
115+
VectorLayoutInterface reshape(Attribute attr,
116+
ArrayRef<int64_t> newShape) const {
117+
auto layout = cast<PackLayoutAttr>(attr);
118+
auto newShapeId =
119+
PackMapAttr::makeIdentity(layout.getMap().getContext(), newShape);
120+
return VectorLayoutInterface(PackLayoutAttr::get(
121+
attr.getContext(), layout.getMap().compose(newShapeId)));
122+
}
123+
124+
bool
125+
needsSharedMemoryForConversion(Attribute attr,
126+
VectorLayoutInterface targetLayout) const {
127+
auto targetLayoutAttr = dyn_cast_if_present<PackLayoutAttr>(targetLayout);
128+
if (!targetLayoutAttr) {
129+
return true;
130+
}
131+
auto srcLayoutAttr = cast<PackLayoutAttr>(attr);
132+
return srcLayoutAttr.coalesce() != targetLayoutAttr.coalesce();
133+
}
134+
135+
static VectorLayoutInterface
136+
getRecombinedLayout(ArrayRef<VectorLayoutInterface> layouts,
137+
ArrayRef<AffineMap> maps, AffineMap resultMap) {
138+
if (!llvm::all_of(layouts, llvm::IsaPred<PackLayoutAttr>)) {
139+
return VectorLayoutInterface();
140+
}
141+
MLIRContext *ctx = resultMap.getContext();
142+
143+
SmallVector<PackLayoutAttr> packLayouts;
144+
llvm::transform(layouts, std::back_inserter(packLayouts),
145+
[](VectorLayoutInterface layout) {
146+
return cast<PackLayoutAttr>(layout);
147+
});
148+
149+
int64_t resRank = resultMap.getNumResults();
150+
151+
// Null Attribute serves as "not yet assigned" sentinel.
152+
Attribute unset;
153+
SmallVector<Attribute> modeShapes(resRank, unset);
154+
SmallVector<Attribute> modeStrides(resRank, unset);
155+
156+
for (auto [layout, indexingMap] : llvm::zip(packLayouts, maps)) {
157+
for (int64_t resultIdx = 0;
158+
resultIdx < static_cast<int64_t>(indexingMap.getNumResults());
159+
++resultIdx) {
160+
auto dimExpr =
161+
dyn_cast<AffineDimExpr>(indexingMap.getResult(resultIdx));
162+
if (!dimExpr) {
163+
continue;
164+
}
165+
int64_t iterPos = dimExpr.getPosition();
166+
auto maybeResPos =
167+
resultMap.getResultPosition(getAffineDimExpr(iterPos, ctx));
168+
if (!maybeResPos.has_value()) {
169+
continue;
170+
}
171+
int64_t resPos = maybeResPos.value();
172+
173+
Attribute ms = layout.getShapeMode(resultIdx);
174+
Attribute md = layout.getStrideMode(resultIdx);
175+
176+
if (modeShapes[resPos] && modeShapes[resPos] != ms) {
177+
return VectorLayoutInterface();
178+
}
179+
modeShapes[resPos] = ms;
180+
modeStrides[resPos] = md;
181+
}
182+
}
183+
184+
for (int64_t i = 0; i < resRank; ++i) {
185+
if (!modeShapes[i]) {
186+
modeShapes[i] = makeLeaf(ctx, 1);
187+
modeStrides[i] = makeLeaf(ctx, 0);
188+
}
189+
}
190+
191+
return VectorLayoutInterface(PackLayoutAttr::get(
192+
ctx, makeTuple(ctx, modeShapes), makeTuple(ctx, modeStrides)));
193+
}
194+
};
195+
196+
} // namespace
197+
198+
void registerVectorLayoutInterfaceExternalModels(DialectRegistry &registry) {
199+
registry.addExtension(+[](MLIRContext *context, IREEMapDialect *dialect) {
200+
PackLayoutAttr::attachInterface<PackLayoutModel>(*context);
201+
});
202+
}
203+
204+
} // namespace mlir::iree_compiler::IREE::Map
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// Copyright 2026 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#ifndef IREE_COMPILER_CODEGEN_DIALECT_MAP_EXTERNALINTERFACES_VECTORLAYOUTINTERFACEIMPL_H_
8+
#define IREE_COMPILER_CODEGEN_DIALECT_MAP_EXTERNALINTERFACES_VECTORLAYOUTINTERFACEIMPL_H_
9+
10+
#include "mlir/IR/DialectRegistry.h"
11+
12+
namespace mlir::iree_compiler::IREE::Map {
13+
14+
void registerVectorLayoutInterfaceExternalModels(DialectRegistry &registry);
15+
16+
} // namespace mlir::iree_compiler::IREE::Map
17+
18+
#endif // IREE_COMPILER_CODEGEN_DIALECT_MAP_EXTERNALINTERFACES_VECTORLAYOUTINTERFACEIMPL_H_
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright 2026 The IREE Authors
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite")
8+
9+
package(
10+
features = ["layering_check"],
11+
licenses = ["notice"], # Apache 2.0
12+
)
13+
14+
iree_lit_test_suite(
15+
name = "lit",
16+
srcs = [
17+
"pack_layout_vector_analysis.mlir",
18+
],
19+
cfg = "//compiler:lit.cfg.py",
20+
tools = [
21+
"//tools:iree-opt",
22+
"@llvm-project//llvm:FileCheck",
23+
],
24+
)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
################################################################################
2+
# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
3+
# compiler/src/iree/compiler/Codegen/Dialect/Map/ExternalInterfaces/test/BUILD.bazel#
4+
# #
5+
# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
6+
# CMake-only content. #
7+
# #
8+
# To disable autogeneration for this file entirely, delete this header. #
9+
################################################################################
10+
11+
iree_add_all_subdirs()
12+
13+
iree_lit_test_suite(
14+
NAME
15+
lit
16+
SRCS
17+
"pack_layout_vector_analysis.mlir"
18+
TOOLS
19+
FileCheck
20+
iree-opt
21+
)
22+
23+
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###

0 commit comments

Comments
 (0)