Skip to content

Commit 314b4e6

Browse files
qedawkinsclaude
andauthored
[Codegen] Add PipelineAttrInterface and PassPipelineAttr (#23590)
This change adds an attribute interface for representing pass pipelines and a single basic attribute that uses the string based pass interpreter to populate a pipeline. The intent of this change is NOT to induce a refactor of all the pass pipelines, instead it's primarily to make testing structural pipeline changes with partially lowered inputs much easier. Today if you want to work on a change that affects later stages of a pass pipeline but will also require changes to earlier steps, it's hard to stage those changes since there isn't a convenient way to jump into the middle of a codegen pass pipeline (unlike the rest of the compiler which offers distinct stages). --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent d8b6df3 commit 314b4e6

27 files changed

Lines changed: 489 additions & 174 deletions

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ iree_compiler_cc_library(
115115
"@llvm-project//mlir:LinalgDialect",
116116
"@llvm-project//mlir:MemRefDialect",
117117
"@llvm-project//mlir:Parser",
118+
"@llvm-project//mlir:Pass",
118119
"@llvm-project//mlir:SCFDialect",
119120
"@llvm-project//mlir:SCFTransforms",
120121
"@llvm-project//mlir:Support",

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ iree_cc_library(
6666
MLIRLinalgDialect
6767
MLIRMemRefDialect
6868
MLIRParser
69+
MLIRPass
6970
MLIRSCFDialect
7071
MLIRSCFTransforms
7172
MLIRSupport

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp

Lines changed: 91 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,50 @@
1919
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
2020
#include "mlir/IR/DialectImplementation.h"
2121
#include "mlir/IR/StorageUniquerSupport.h"
22+
#include "mlir/Pass/PassRegistry.h"
23+
24+
// Custom parse/print directives for TranslationInfoAttr's pipeline field.
25+
// These must be defined before the generated .cpp.inc is included because
26+
// the ODS-generated parse/print methods call them.
27+
namespace mlir::iree_compiler::IREE::Codegen {
28+
29+
/// Parses either a DispatchLoweringPassPipeline enum keyword (e.g.,
30+
/// `CPUDefault`) or a generic attribute implementing PipelineAttrInterface
31+
/// (e.g., `#iree_codegen.pass_pipeline<"canonicalize">`).
32+
static ParseResult parsePipelineAttr(AsmParser &parser, Attribute &result) {
33+
StringRef keyword;
34+
SMLoc loc = parser.getCurrentLocation();
35+
if (succeeded(parser.parseOptionalKeyword(&keyword))) {
36+
std::optional<DispatchLoweringPassPipeline> pipeline =
37+
symbolizeDispatchLoweringPassPipeline(keyword);
38+
if (!pipeline) {
39+
parser.emitError(loc, "unknown pipeline keyword: ") << keyword;
40+
return failure();
41+
}
42+
result =
43+
DispatchLoweringPassPipelineAttr::get(parser.getContext(), *pipeline);
44+
return success();
45+
}
46+
Attribute attr;
47+
if (parser.parseAttribute(attr)) {
48+
return failure();
49+
}
50+
result = attr;
51+
return success();
52+
}
53+
54+
/// Prints DispatchLoweringPassPipelineAttr as a bare keyword and other
55+
/// attributes (e.g., PipelineAttrInterface impls) via the generic printer.
56+
static void printPipelineAttr(AsmPrinter &printer, Attribute pipelineAttr) {
57+
if (auto enumAttr =
58+
dyn_cast<DispatchLoweringPassPipelineAttr>(pipelineAttr)) {
59+
printer << stringifyEnum(enumAttr.getValue());
60+
return;
61+
}
62+
printer.printAttribute(pipelineAttr);
63+
}
64+
65+
} // namespace mlir::iree_compiler::IREE::Codegen
2266

2367
#define GET_ATTRDEF_CLASSES
2468
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp.inc"
@@ -64,6 +108,28 @@ ArrayAttr ExportConfigAttr::getWorkgroupSizeIndexArray() {
64108
return getIndexArrayAttr(getContext(), getWorkgroupSize());
65109
}
66110

111+
//===----------------------------------------------------------------------===//
112+
// iree_codegen.pass_pipeline
113+
//===----------------------------------------------------------------------===//
114+
115+
LogicalResult PassPipelineAttr::buildPipeline(OpPassManager &pm) const {
116+
if (failed(parsePassPipeline(getPipeline(), pm))) {
117+
return failure();
118+
}
119+
return success();
120+
}
121+
122+
LogicalResult
123+
PassPipelineAttr::verify(function_ref<InFlightDiagnostic()> emitError,
124+
StringRef pipeline) {
125+
OpPassManager pm("builtin.module");
126+
if (failed(parsePassPipeline(pipeline, pm))) {
127+
return emitError() << "invalid pass pipeline specification: '" << pipeline
128+
<< "'";
129+
}
130+
return success();
131+
}
132+
67133
//===----------------------------------------------------------------------===//
68134
// iree_codegen.translation_info
69135
//===----------------------------------------------------------------------===//
@@ -72,7 +138,7 @@ TranslationInfoAttr TranslationInfoAttr::get(
72138
MLIRContext *context, DispatchLoweringPassPipeline passPipeline,
73139
SymbolRefAttr codegenSpec, ArrayRef<int64_t> workgroupSize,
74140
std::optional<int64_t> subgroupSize, DictionaryAttr configuration) {
75-
auto pipelineAttr =
141+
Attribute pipelineAttr =
76142
DispatchLoweringPassPipelineAttr::get(context, passPipeline);
77143
return get(context, pipelineAttr, codegenSpec, workgroupSize,
78144
subgroupSize.value_or(int64_t()), configuration);
@@ -82,36 +148,46 @@ TranslationInfoAttr TranslationInfoAttr::get(
82148
MLIRContext *context, DispatchLoweringPassPipeline passPipeline,
83149
ArrayRef<int64_t> workgroupSize, std::optional<int64_t> subgroupSize,
84150
DictionaryAttr configuration) {
85-
auto pipelineAttr =
151+
Attribute pipelineAttr =
86152
DispatchLoweringPassPipelineAttr::get(context, passPipeline);
87153
return get(context, pipelineAttr, /*codegenSpec=*/SymbolRefAttr(),
88154
workgroupSize, subgroupSize.value_or(int64_t()), configuration);
89155
}
90156

91157
DispatchLoweringPassPipeline
92158
TranslationInfoAttr::getDispatchLoweringPassPipeline() {
93-
return getPassPipeline().getValue();
159+
if (auto enumAttr =
160+
dyn_cast<DispatchLoweringPassPipelineAttr>(getPassPipeline())) {
161+
return enumAttr.getValue();
162+
}
163+
return DispatchLoweringPassPipeline::None;
94164
}
95165

96166
LogicalResult TranslationInfoAttr::verify(
97-
function_ref<InFlightDiagnostic()> emitError,
98-
IREE::Codegen::DispatchLoweringPassPipelineAttr passPipeline,
167+
function_ref<InFlightDiagnostic()> emitError, Attribute passPipeline,
99168
SymbolRefAttr codegenSpec, ArrayRef<int64_t> workgroupSize,
100169
int64_t subgroupSize, DictionaryAttr configuration) {
101170
if (!passPipeline) {
102171
return emitError() << "missing pass pipeline specification";
103172
}
104-
auto passPipelineValue = passPipeline.getValue();
105-
if (passPipelineValue > IREE::Codegen::DispatchLoweringPassPipeline::None) {
106-
return emitError() << "invalid pass pipeline value : "
107-
<< stringifyEnum(passPipeline.getValue());
108-
}
109-
auto tdPassPipeline =
110-
IREE::Codegen::DispatchLoweringPassPipeline::TransformDialectCodegen;
111-
if (codegenSpec && passPipelineValue != tdPassPipeline) {
173+
if (auto enumAttr =
174+
dyn_cast<DispatchLoweringPassPipelineAttr>(passPipeline)) {
175+
DispatchLoweringPassPipeline passPipelineValue = enumAttr.getValue();
176+
if (passPipelineValue > IREE::Codegen::DispatchLoweringPassPipeline::None) {
177+
return emitError() << "invalid pass pipeline value : "
178+
<< stringifyEnum(passPipelineValue);
179+
}
180+
DispatchLoweringPassPipeline tdPassPipeline =
181+
IREE::Codegen::DispatchLoweringPassPipeline::TransformDialectCodegen;
182+
if (codegenSpec && passPipelineValue != tdPassPipeline) {
183+
return emitError()
184+
<< "transform dialect codegen spec requires pass pipeline : "
185+
<< stringifyEnum(tdPassPipeline);
186+
}
187+
} else if (!isa<PipelineAttrInterface>(passPipeline)) {
112188
return emitError()
113-
<< "transform dialect codegen spec requires pass pipeline : "
114-
<< stringifyEnum(tdPassPipeline);
189+
<< "pass pipeline must be a DispatchLoweringPassPipelineAttr or "
190+
"implement PipelineAttrInterface";
115191
}
116192
if (workgroupSize.size() > 3) {
117193
return emitError() << "workgroup size cannot have more than 3 entries";

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,32 @@ def IREECodegen_SimpleTargetAttr :
250250
}
251251

252252

253+
//===---------------------------------------------------------------------===//
254+
// iree_codegen.pass_pipeline
255+
//===---------------------------------------------------------------------===//
256+
257+
def IREECodegen_PassPipelineAttr :
258+
AttrDef<IREECodegen_Dialect, "PassPipeline", [
259+
DeclareAttrInterfaceMethods<IREECodegen_PipelineAttrInterface, [
260+
"buildPipeline"
261+
]>
262+
]> {
263+
let mnemonic = "pass_pipeline";
264+
let summary = "An attribute carrying a textual pass pipeline string.";
265+
let description = [{
266+
Specifies a pass pipeline using MLIR's textual pass pipeline syntax.
267+
The pipeline string is parsed and populated into an OpPassManager
268+
when `buildPipeline` is called.
269+
}];
270+
let parameters = (ins
271+
StringRefParameter<"The textual pass pipeline specification">:$pipeline
272+
);
273+
let assemblyFormat = [{
274+
`<` $pipeline `>`
275+
}];
276+
let genVerifyDecl = 1;
277+
}
278+
253279
//===---------------------------------------------------------------------===//
254280
// iree_codegen.translation_info
255281
//===---------------------------------------------------------------------===//
@@ -269,22 +295,15 @@ def IREECodegen_TranslationInfoAttr :
269295
dispatch region (like `linalg.matmul`/`linalg.*conv*`), this
270296
attribute gets propagated to the entry point function.
271297

272-
The fields are
273-
- `passPipeline` : The pass pipeline to use.
274-
275-
}];
276-
277-
let assemblyFormat = [{
278-
`<` `pipeline` `=` `` $passPipeline
279-
(`codegen_spec` `=` $codegenSpec^)?
280-
(`workgroup_size` `=` `[` $workgroupSize^ `]`)?
281-
(`subgroup_size` `=` $subgroupSize^)?
282-
(`,` $configuration^)? `>`
298+
The `passPipeline` field can be either:
299+
- A `DispatchLoweringPassPipelineAttr` (enum keyword like `CPUDefault`).
300+
- Any attribute implementing `PipelineAttrInterface` (e.g.,
301+
`#iree_codegen.pass_pipeline<"...">`).
283302
}];
284303

285304
let parameters = (ins
286-
AttrParameter<"IREE::Codegen::DispatchLoweringPassPipelineAttr",
287-
"Name of the pipeline to be invoked on the translation unit.">:$passPipeline,
305+
AttrParameter<"Attribute",
306+
"Pass pipeline specification.">:$passPipeline,
288307
OptionalParameter<"SymbolRefAttr",
289308
"The symbol pointing to the transform dialect codegen spec to be used">:$codegenSpec,
290309
OptionalArrayRefParameter<"int64_t", "The workgroup size to use">:$workgroupSize,
@@ -304,9 +323,20 @@ def IREECodegen_TranslationInfoAttr :
304323
CArg<"DictionaryAttr", "{}">:$configuration)>
305324
];
306325
let extraClassDeclaration = [{
307-
// Returns the lowering pass pipeline set.
326+
// Returns the lowering pass pipeline enum value. Returns None if the
327+
// pipeline is not a DispatchLoweringPassPipelineAttr.
308328
DispatchLoweringPassPipeline getDispatchLoweringPassPipeline();
309329
}];
330+
331+
let assemblyFormat = [{
332+
`<` `pipeline` `=` custom<PipelineAttr>($passPipeline)
333+
(`codegen_spec` `=` $codegenSpec^)?
334+
(`workgroup_size` `=` `[` $workgroupSize^ `]`)?
335+
(`subgroup_size` `=` $subgroupSize^)?
336+
(`,` $configuration^)?
337+
`>`
338+
}];
339+
310340
let genVerifyDecl = 1;
311341
}
312342

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/IR/Builders.h"
1515
#include "mlir/IR/BuiltinAttributes.h"
1616
#include "mlir/IR/BuiltinTypes.h"
17+
#include "mlir/Pass/PassManager.h"
1718

1819
#include "llvm/ADT/STLExtras.h"
1920

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,26 @@ def IREECodegen_UKernelProviderInterface :
752752
];
753753
}
754754

755+
def IREECodegen_PipelineAttrInterface :
756+
AttrInterface<"PipelineAttrInterface"> {
757+
let cppNamespace = "::mlir::iree_compiler::IREE::Codegen";
758+
let description = [{
759+
Attribute interface for building a pass pipeline. Implementations populate
760+
the provided OpPassManager with the desired pass pipeline.
761+
}];
762+
763+
let methods = [
764+
InterfaceMethod<
765+
/*desc=*/[{
766+
Populates the given pass manager with a pass pipeline.
767+
}],
768+
/*retTy=*/"::mlir::LogicalResult",
769+
/*methodName=*/"buildPipeline",
770+
/*args=*/(ins "::mlir::OpPassManager &":$pm)
771+
>
772+
];
773+
}
774+
755775
def IREECodegen_TargetInfoAttrInterface :
756776
AttrInterface<"TargetInfoAttrInterface"> {
757777
let cppNamespace = "::mlir::iree_compiler::IREE::Codegen";

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/lowering_config_attr.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,37 @@ module {
125125
}
126126
}
127127
}
128+
129+
// -----
130+
131+
module {
132+
/// Pass pipeline attribute round-trips correctly.
133+
func.func @test_pass_pipeline() attributes {
134+
translation_info = #iree_codegen.translation_info<pipeline = #iree_codegen.pass_pipeline<"canonicalize">>} {
135+
return
136+
}
137+
}
138+
// CHECK: #translation = #iree_codegen.translation_info<pipeline = #iree_codegen.pass_pipeline<"canonicalize">>
139+
140+
// -----
141+
142+
module {
143+
/// Pass pipeline attribute with workgroup size and subgroup size round-trips.
144+
func.func @test_pass_pipeline_with_config() attributes {
145+
translation_info = #iree_codegen.translation_info<pipeline = #iree_codegen.pass_pipeline<"canonicalize"> workgroup_size = [64, 1, 1] subgroup_size = 32>} {
146+
return
147+
}
148+
}
149+
// CHECK: #translation = #iree_codegen.translation_info<pipeline = #iree_codegen.pass_pipeline<"canonicalize"> workgroup_size = [64, 1, 1] subgroup_size = 32>
150+
151+
// -----
152+
153+
module {
154+
/// Invalid pass pipeline string should be caught at verify time.
155+
func.func @invalid_pass_pipeline() attributes {
156+
// expected-error @+1 {{invalid pass pipeline specification: 'not_a_real_pass'}}
157+
translation_info = #iree_codegen.translation_info<pipeline = #iree_codegen.pass_pipeline<"not_a_real_pass">>
158+
} {
159+
return
160+
}
161+
}

compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3976,9 +3976,10 @@ adjustTileSizesForRootUnPackOp(mlir::FunctionOpInterface entryPointFn,
39763976
}
39773977
}
39783978

3979-
auto tInfo = getTranslationInfo(entryPointFn);
3980-
auto pipeline = tInfo.getPassPipeline().getValue();
3981-
auto pipelineConfig = tInfo.getConfiguration();
3979+
IREE::Codegen::TranslationInfoAttr tInfo = getTranslationInfo(entryPointFn);
3980+
DispatchLoweringPassPipeline pipeline =
3981+
tInfo.getDispatchLoweringPassPipeline();
3982+
DictionaryAttr pipelineConfig = tInfo.getConfiguration();
39823983
if (isOptEnabled(entryPointFn, getEnableLoopPeelingStr())) {
39833984
// See #16406
39843985
LDBG() << "unpack fusion does not work with peeling, falling back to "
@@ -4167,7 +4168,8 @@ setTranslationInfoAndRootConfig(mlir::FunctionOpInterface entryPointFn,
41674168

41684169
// The transform dialect codegen has different logics and codegen flow.
41694170
// Ignore the tile sizes adjustment.
4170-
auto pipeline = getTranslationInfo(entryPointFn).getPassPipeline().getValue();
4171+
DispatchLoweringPassPipeline pipeline =
4172+
getTranslationInfo(entryPointFn).getDispatchLoweringPassPipeline();
41714173
if (pipeline != DispatchLoweringPassPipeline::TransformDialectCodegen) {
41724174
if (failed(adjustTileSizesForRootUnPackOp(entryPointFn, rootOperation))) {
41734175
return failure();

0 commit comments

Comments
 (0)