From 1a590b904e9bdda43d550f31339a407046ca4bcf Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Sat, 7 Mar 2026 08:54:30 -0500 Subject: [PATCH 1/2] [Tuner][Codegen] Add iree_codegen.constraints op and supporting infra Today, the tuner has to know a lot about the IREE compiler and its compilation pipelines, including what inputs they can compile, their configuration space, their lowering_config format, etc. This is partially encoded directly as Python logic and partially exposed through Python bindings. This is the first step towards moving the constraint generation responsibility from the tuner to the compiler. The constraints op encodes pipeline constraints over one or more root ops and can later be lowered to SMT-LIB for solving by the tuner, or verified against the selected lowering_config after dispatch configuration. See the discussion for the full proposal: https://github.com/iree-org/iree/discussions/23521 Key design decisions: - The knobs dict structurally mirrors the lowering_config / translation_info attributes, with tunable leaves as `#iree_codegen.int_knob<"name">`. This makes it possible for the tuner to mechanically substitute solved knob values back into concrete attributes without understanding their structure. - Problem dimensions come in as index operands (from tensor.dim / constants), so that static sizes get constant-folded into the SMT body and dynamic shapes remain symbolic. - The body uses upstream SMT dialect ops, so that we can directly export to SMT-LIB without a custom lowering. - Pipeline attr accepts both DispatchLoweringPassPipelineAttr and the new PipelineAttrInterface (#23590) to support custom pipelines. Co-Authored-By: Claude Opus 4.6 --- .../bazel_to_cmake/bazel_to_cmake_targets.py | 6 + compiler/bindings/python/CMakeLists.txt | 1 + .../bindings/python/test/ir/dialects_test.py | 1 + compiler/src/iree/compiler/API/BUILD.bazel | 1 + compiler/src/iree/compiler/API/CMakeLists.txt | 8 +- .../Codegen/Dialect/Codegen/IR/BUILD.bazel | 2 + .../Codegen/Dialect/Codegen/IR/CMakeLists.txt | 1 + .../Dialect/Codegen/IR/IREECodegenAttrs.cpp | 4 +- .../Dialect/Codegen/IR/IREECodegenAttrs.h | 19 ++++ .../Dialect/Codegen/IR/IREECodegenAttrs.td | 18 +++ .../Dialect/Codegen/IR/IREECodegenDialect.cpp | 1 + .../Dialect/Codegen/IR/IREECodegenDialect.td | 5 +- .../Dialect/Codegen/IR/IREECodegenOps.cpp | 97 ++++++++++++++++ .../Dialect/Codegen/IR/IREECodegenOps.h | 1 + .../Dialect/Codegen/IR/IREECodegenOps.td | 89 +++++++++++++++ .../Dialect/Codegen/IR/test/invalid.mlir | 106 ++++++++++++++++++ .../Dialect/Codegen/IR/test/roundtrip.mlir | 101 +++++++++++++++++ compiler/src/iree/compiler/Tools/BUILD.bazel | 1 + .../src/iree/compiler/Tools/CMakeLists.txt | 1 + .../iree/compiler/Tools/init_mlir_dialects.h | 2 + 20 files changed, 461 insertions(+), 4 deletions(-) diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py index fbe8a4171619..7b6a74ed1731 100644 --- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py +++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py @@ -81,6 +81,12 @@ def __init__(self, repo_map: Dict[str, str]): "@llvm-project//mlir:MlirLspServerLib": ["MLIRLspServerLib"], "@llvm-project//mlir:MlirTableGenMain": ["MLIRTableGen"], "@llvm-project//mlir:MlirOptLib": ["MLIROptLib"], + "@llvm-project//mlir:CAPISMT": [ + "MLIRCAPISMT", + "MLIRCAPIExportSMTLIB", + ], + "@llvm-project//mlir:SMTDialect": ["MLIRSMT"], + "@llvm-project//mlir:TargetSMTLIB": ["MLIRExportSMTLIB"], "@llvm-project//mlir:VectorOps": ["MLIRVector"], # StableHLO. "@stablehlo//:chlo_ops": [ diff --git a/compiler/bindings/python/CMakeLists.txt b/compiler/bindings/python/CMakeLists.txt index 7dd7c559588c..2e695bc24f29 100644 --- a/compiler/bindings/python/CMakeLists.txt +++ b/compiler/bindings/python/CMakeLists.txt @@ -248,6 +248,7 @@ set(_SOURCE_COMPONENTS MLIRPythonSources.Dialects.rocdl MLIRPythonSources.Dialects.scf MLIRPythonSources.Dialects.shape + MLIRPythonSources.Dialects.smt MLIRPythonSources.Dialects.structured_transform MLIRPythonSources.Dialects.tensor MLIRPythonSources.Dialects.tosa diff --git a/compiler/bindings/python/test/ir/dialects_test.py b/compiler/bindings/python/test/ir/dialects_test.py index 5a5098946049..76f52bd76a38 100644 --- a/compiler/bindings/python/test/ir/dialects_test.py +++ b/compiler/bindings/python/test/ir/dialects_test.py @@ -39,6 +39,7 @@ def decorator_builder(func): rocdl, scf, shape, + smt, tensor, tosa, transform, diff --git a/compiler/src/iree/compiler/API/BUILD.bazel b/compiler/src/iree/compiler/API/BUILD.bazel index 43ff10395ded..9ca3ea751001 100644 --- a/compiler/src/iree/compiler/API/BUILD.bazel +++ b/compiler/src/iree/compiler/API/BUILD.bazel @@ -46,6 +46,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:CAPILLVM", "@llvm-project//mlir:CAPILinalg", "@llvm-project//mlir:CAPIPDL", + "@llvm-project//mlir:CAPISMT", "@llvm-project//mlir:CAPITarget", "@llvm-project//mlir:CAPITransformDialect", "@llvm-project//mlir:CAPITransformDialectTransforms", diff --git a/compiler/src/iree/compiler/API/CMakeLists.txt b/compiler/src/iree/compiler/API/CMakeLists.txt index dce278ff2f91..30f6c7a8d519 100644 --- a/compiler/src/iree/compiler/API/CMakeLists.txt +++ b/compiler/src/iree/compiler/API/CMakeLists.txt @@ -17,12 +17,14 @@ iree_cc_library( IREEDialectsCAPI MLIRCAPIAMDGPU MLIRCAPIDebug + MLIRCAPIExportSMTLIB MLIRCAPIGPU MLIRCAPIIR MLIRCAPIInterfaces MLIRCAPILLVM MLIRCAPILinalg MLIRCAPIPDL + MLIRCAPISMT MLIRCAPITarget MLIRCAPITransformDialect MLIRCAPITransformDialectTransforms @@ -80,6 +82,8 @@ set(_EXPORT_OBJECT_LIBS obj.MLIRCAPILLVM obj.MLIRCAPILinalg obj.MLIRCAPIPDL + obj.MLIRCAPISMT + obj.MLIRCAPIExportSMTLIB obj.MLIRCAPITarget obj.MLIRCAPITransforms obj.MLIRCAPITransformDialect @@ -194,9 +198,11 @@ target_link_libraries(iree_compiler_API_SharedImpl PRIVATE ${_EXPORT_OBJECT_DEPS} ) -# Link MLIRTargetLLVMIRImport directly since it is not exported as an object library. +# Link MLIRTargetLLVMIRImport and MLIRExportSMTLIB directly since they are +# not exported as object libraries. target_link_libraries(iree_compiler_API_SharedImpl PRIVATE MLIRTargetLLVMIRImport + MLIRExportSMTLIB ) # If not using sanitizers, ask linkers to error on undefined symbols. diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel index d0c06f7ce140..a502ee974f6c 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel @@ -44,6 +44,7 @@ iree_td_library( "@llvm-project//mlir:LinalgOpsTdFiles", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:SCFTdFiles", + "@llvm-project//mlir:SMTTdFiles", "@llvm-project//mlir:SideEffectInterfacesTdFiles", "@llvm-project//mlir:TilingInterfaceTdFiles", "@llvm-project//mlir:VectorInterfacesTdFiles", @@ -118,6 +119,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:SMTDialect", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TilingInterface", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/CMakeLists.txt index d26c561bcab3..ce709de10b83 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/CMakeLists.txt @@ -69,6 +69,7 @@ iree_cc_library( MLIRPass MLIRSCFDialect MLIRSCFTransforms + MLIRSMT MLIRSupport MLIRTensorDialect MLIRTilingInterface diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp index 441e612b08ec..a90188a53a3c 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp @@ -29,7 +29,7 @@ namespace mlir::iree_compiler::IREE::Codegen { /// Parses either a DispatchLoweringPassPipeline enum keyword (e.g., /// `CPUDefault`) or a generic attribute implementing PipelineAttrInterface /// (e.g., `#iree_codegen.pass_pipeline<"canonicalize">`). -static ParseResult parsePipelineAttr(AsmParser &parser, Attribute &result) { +ParseResult parsePipelineAttr(AsmParser &parser, Attribute &result) { StringRef keyword; SMLoc loc = parser.getCurrentLocation(); if (succeeded(parser.parseOptionalKeyword(&keyword))) { @@ -53,7 +53,7 @@ static ParseResult parsePipelineAttr(AsmParser &parser, Attribute &result) { /// Prints DispatchLoweringPassPipelineAttr as a bare keyword and other /// attributes (e.g., PipelineAttrInterface impls) via the generic printer. -static void printPipelineAttr(AsmPrinter &printer, Attribute pipelineAttr) { +void printPipelineAttr(AsmPrinter &printer, Attribute pipelineAttr) { if (auto enumAttr = dyn_cast(pipelineAttr)) { printer << stringifyEnum(enumAttr.getValue()); diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h index 474b8444ee4a..0556dc2afda8 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h @@ -36,6 +36,25 @@ bool shouldSetTunerAttributes(); #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h.inc" // clang-format on +namespace mlir::iree_compiler::IREE::Codegen { + +/// Parses either a DispatchLoweringPassPipeline enum keyword or a generic +/// attribute implementing PipelineAttrInterface. +ParseResult parsePipelineAttr(AsmParser &parser, Attribute &result); +inline ParseResult parsePipelineAttr(OpAsmParser &parser, Attribute &result) { + return parsePipelineAttr(static_cast(parser), result); +} + +/// Prints DispatchLoweringPassPipelineAttr as a bare keyword and other +/// attributes via the generic printer. +void printPipelineAttr(AsmPrinter &printer, Attribute pipelineAttr); +inline void printPipelineAttr(OpAsmPrinter &printer, Operation *, + Attribute pipelineAttr) { + printPipelineAttr(printer, pipelineAttr); +} + +} // namespace mlir::iree_compiler::IREE::Codegen + namespace mlir::iree_compiler { //===----------------------------------------------------------------------===// // Constant names. diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td index c43a48a29fb4..7892e47f759e 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td @@ -732,4 +732,22 @@ def IREECodegen_RootOpAttr : AttrDef { let assemblyFormat = "`<` `set` `=` $set `>`"; } +//===---------------------------------------------------------------------===// +// iree_codegen.int_knob +//===---------------------------------------------------------------------===// + +def IREECodegen_IntKnobAttr : AttrDef { + let mnemonic = "int_knob"; + let summary = "Integer-valued tunable knob placeholder."; + let description = [{ + Represents a named placeholder for an integer tunable parameter in a + constraints knobs dictionary. During constraint generation, these appear + in tiling arrays (workgroup, reduction, thread), workgroup_size, and + subgroup_size positions. The name matches the corresponding + `iree_codegen.knob` op name. + }]; + let parameters = (ins "StringAttr":$name); + let assemblyFormat = "`<` $name `>`"; +} + #endif // IREE_COMPILER_CODEGEN_DIALECT_IREECODEGENATTRS diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.cpp index 48b5b02f8dbb..431fe3c8c6de 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.cpp @@ -11,6 +11,7 @@ #include "iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h" #include "iree/compiler/Codegen/Dialect/PCF/IR/PCFInterfaces.h" #include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/SMT/IR/SMTDialect.h" #include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/IR/DialectImplementation.h" // clang-format off diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.td index 7ac568fde743..cbf306b55b85 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.td @@ -107,7 +107,10 @@ def IREECodegen_Dialect : Dialect { // We use linalg.iterator_type to avoid clashing definitions for the // enum attribute that wraps mlir::utils::IteratorType, so we need to be // able to parse it. - let dependentDialects = ["::mlir::linalg::LinalgDialect"]; + let dependentDialects = [ + "::mlir::linalg::LinalgDialect", + "::mlir::smt::SMTDialect", + ]; let useDefaultTypePrinterParser = 1; let useDefaultAttributePrinterParser = 1; let hasOperationAttrVerify = 1; diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.cpp index 98e6571b4cfc..496cb93ec2f4 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.cpp @@ -6,14 +6,17 @@ #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h" #include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SMT/IR/SMTTypes.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/AffineMap.h" @@ -22,6 +25,23 @@ #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Support/LLVM.h" +// Custom parse/print helper for the knobs dictionary in constraints op. +// Prints `knobs = { ... }` on its own line with newlines before and after. +static mlir::ParseResult parseKnobsDictionary(mlir::OpAsmParser &parser, + mlir::DictionaryAttr &attr) { + if (parser.parseKeyword("knobs") || parser.parseEqual()) { + return mlir::failure(); + } + return parser.parseAttribute(attr); +} +static void printKnobsDictionary(mlir::OpAsmPrinter &p, mlir::Operation *, + mlir::DictionaryAttr attr) { + p.printNewline(); + p << " knobs = "; + p.printAttributeWithoutType(attr); + p.printNewline(); +} + // clang-format off #define GET_OP_CLASSES #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.cpp.inc" // IWYU pragma: keep @@ -452,3 +472,80 @@ void WorkgroupCountHintOp::build(OpBuilder &builder, OperationState &state, build(builder, state, dynamicSizes, builder.getDenseI64ArrayAttr(staticSizes)); } + +//===----------------------------------------------------------------------===// +// ConstraintsOp +//===----------------------------------------------------------------------===// + +/// Recursively check whether `name` appears as a knob name in `attr`. +/// Checks IntKnobAttr names and recurses into DictionaryAttr/ArrayAttr. +static bool hasKnobName(Attribute attr, StringRef name) { + if (auto intKnob = dyn_cast(attr)) { + return intKnob.getName().getValue() == name; + } + if (auto dictAttr = dyn_cast(attr)) { + for (NamedAttribute entry : dictAttr) { + if (hasKnobName(entry.getValue(), name)) { + return true; + } + } + return false; + } + if (auto arrayAttr = dyn_cast(attr)) { + for (Attribute element : arrayAttr) { + if (hasKnobName(element, name)) { + return true; + } + } + return false; + } + return false; +} + +LogicalResult ConstraintsOp::verify() { + Block &block = getBody().front(); + + // Check block arg count matches problem_dims count. + if (block.getNumArguments() != getProblemDims().size()) { + return emitOpError("expected ") + << getProblemDims().size() << " block arguments but got " + << block.getNumArguments(); + } + + // Check all block args are !smt.int. + smt::IntType smtIntType = smt::IntType::get(getContext()); + for (auto [i, arg] : llvm::enumerate(block.getArguments())) { + if (arg.getType() != smtIntType) { + return emitOpError("block argument #") + << i << " must be !smt.int but got " << arg.getType(); + } + } + + // Verify knob ops: check names exist in the dict and reject duplicates. + // Note that we considered using SymbolTable for uniqueness, but the knobs + // dictionary contains attributes (not ops), so we'd still need custom + // verification for dictionary <--> KnobOp correspondence. + // Rejecting duplicates is not just pedantic -- when this op is lowered to + // SMT, each KnobOp becomes an `smt.declare_const`. The SMT dialect creates + // a fresh symbolic constant per declaration regardless of the name string, + // so two KnobOps with the same name would silently introduce two independent + // solver variables where one was intended, producing incorrect constraints. + DictionaryAttr knobs = getKnobsAttr(); + llvm::StringMap seenKnobs; + for (auto knobOp : block.getOps()) { + auto [it, inserted] = + seenKnobs.try_emplace(knobOp.getName(), knobOp.getLoc()); + if (!inserted) { + InFlightDiagnostic diag = knobOp.emitOpError("duplicate knob name '") + << knobOp.getName() << "'"; + diag.attachNote(it->second) << "first occurrence here"; + return diag; + } + if (!hasKnobName(knobs, knobOp.getName())) { + return knobOp.emitOpError("knob name '") + << knobOp.getName() << "' not found in knobs dict"; + } + } + + return success(); +} diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h index 5ad4e7a1e052..8ca29fbeddb8 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h @@ -12,6 +12,7 @@ #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h" #include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SMT/IR/SMTTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.td index 6b314d6cb211..6f91d185f205 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.td @@ -22,6 +22,7 @@ include "mlir/Interfaces/TilingInterface.td" include "mlir/Interfaces/VectorInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Dialect/SMT/IR/SMTTypes.td" def TensorTypeAttr : TypeAttrBase<"::mlir::TensorType", "Tensor type attribute">; @@ -615,4 +616,92 @@ def IREECodegen_IndexHintOp : Op { let assemblyFormat = "$input `(` $hint `)` attr-dict `:` type($input)"; } +//===----------------------------------------------------------------------===// +// ConstraintsOp +//===----------------------------------------------------------------------===// + +def IREECodegen_ConstraintsOp : + Op { + let summary = "SMT constraints for a codegen configuration of root ops."; + let description = [{ + Declares SMT constraints over problem dimensions and configuration + knobs for a codegen pipeline, targeting a set of root ops. + + `target`: A `#iree_codegen.root_op` attribute identifying + which root op set these constraints apply to. All ops marked with + the same `#iree_codegen.root_op` attribute share the same + lowering config. This decouples constraints from SSA values, so ops + with zero or multiple results are supported. + + `pipeline`: The codegen pipeline to use. This is a fixed choice, + not decided by the solver. + + `knobs`: DictionaryAttr mirroring GPULoweringConfigAttr. Leaves + that are `#iree_codegen.int_knob<"name">` attrs name tunable SMT + constants (materialized by `iree_codegen.knob` ops in the body). + Integer/attr leaves are fixed. + + `problem_dims`: index-typed problem dimensions; corresponding block + arguments are !smt.int. + + Does not have any execution semantics and is meant to be used by the + tuner or verification passes, and erased before lowering. + + Example: + ```mlir + // The matmul is marked: {root_op = #iree_codegen.root_op} + iree_codegen.constraints target = , pipeline = LLVMGPUVectorDistribute, + knobs = {workgroup = [#iree_codegen.int_knob<"wg_m">, #iree_codegen.int_knob<"wg_n">]} + dims(%M, %N, %K) { + ^bb0(%m: !smt.int, %n: !smt.int, %k: !smt.int): + %wg_m = iree_codegen.knob "wg_m" : !smt.int + %wg_n = iree_codegen.knob "wg_n" : !smt.int + } + ``` + }]; + let arguments = (ins + IREECodegen_RootOpAttr:$target, + AnyAttrOf<[DispatchLoweringPassPipelineAttr, + IREECodegen_PipelineAttrInterface]>:$pipeline, + DictionaryAttr:$knobs, + Variadic:$problem_dims + ); + let regions = (region SizedRegion<1>:$body); + let results = (outs); + let assemblyFormat = [{ + `target` `=` $target `,` `pipeline` `=` custom($pipeline) `,` + custom($knobs) + `dims` `(` $problem_dims `)` attr-dict-with-keyword + $body + }]; + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// KnobOp +//===----------------------------------------------------------------------===// + +def IREECodegen_KnobOp : + Op]> { + let summary = "Declare an SMT constant for a tunable configuration knob."; + let description = [{ + Materializes a named SMT constant (!smt.int) for use in constraint + expressions. The name must match an `#iree_codegen.int_knob<"name">` + leaf in the enclosing `iree_codegen.constraints` op's `knobs` + dictionary. + + In SMT terminology this is a constant (0-ary function), not a variable. + The tuner assigns concrete integer values to these constants. + + Example: + ```mlir + %wg_m = iree_codegen.knob "wg_m" : !smt.int + ``` + }]; + let arguments = (ins StrAttr:$name); + let results = (outs IntType:$result); + let assemblyFormat = "$name attr-dict `:` type($result)"; +} + #endif // IREE_CODEGEN_DIALECT_IREECODEGENOPS diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/invalid.mlir b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/invalid.mlir index c5795917ab74..bf65a168af8b 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/invalid.mlir +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/invalid.mlir @@ -42,3 +42,109 @@ func.func @store_to_buffer_invalid_element_type(%arg0: tensor<4xf16>, %arg1: mem iree_codegen.store_to_buffer %arg0, %arg1 : tensor<4xf16> into memref<4xf32> return } + +// ----- + +// Constraints op: block arg wrong type. +func.func @constraints_block_arg_wrong_type(%arg0: index) { + // expected-error @+1 {{'iree_codegen.constraints' op block argument #0 must be !smt.int but got 'index'}} + iree_codegen.constraints target = , pipeline = None, + knobs = {} + dims(%arg0) { + ^bb0(%m: index): + } + return +} + +// ----- + +// KnobOp outside of ConstraintsOp. +func.func @knob_outside_constraints() { + // expected-error @+1 {{'iree_codegen.knob' op expects parent op 'iree_codegen.constraints'}} + %x = iree_codegen.knob "foo" : !smt.int + return +} + +// ----- + +// Constraints op: block arg count mismatch with problem_dims. +func.func @constraints_block_arg_mismatch(%arg0: index) { + // expected-error @+1 {{'iree_codegen.constraints' op expected 1 block arguments but got 2}} + iree_codegen.constraints target = , pipeline = None, + knobs = {} + dims(%arg0) { + ^bb0(%m: !smt.int, %extra: !smt.int): + } + return +} + +// ----- + +// Knob op: duplicate knob name. +func.func @duplicate_knob_name(%arg0: index) { + iree_codegen.constraints target = , pipeline = None, + knobs = {workgroup = [#iree_codegen.int_knob<"wg_m">]} + dims(%arg0) { + ^bb0(%m: !smt.int): + // expected-note @+1 {{first occurrence here}} + %first = iree_codegen.knob "wg_m" : !smt.int + // expected-error @+1 {{'iree_codegen.knob' op duplicate knob name 'wg_m'}} + %second = iree_codegen.knob "wg_m" : !smt.int + } + return +} + +// ----- + +// Constraints op: too few block args for problem_dims. +func.func @constraints_block_arg_too_few(%arg0: index, %arg1: index) { + // expected-error @+1 {{'iree_codegen.constraints' op expected 2 block arguments but got 1}} + iree_codegen.constraints target = , pipeline = None, + knobs = {} + dims(%arg0, %arg1) { + ^bb0(%m: !smt.int): + } + return +} + +// ----- + +// Knob op: knob name not found in knobs dict. +func.func @knob_name_not_found(%arg0: index) { + iree_codegen.constraints target = , pipeline = None, + knobs = {workgroup = [#iree_codegen.int_knob<"wg_m">]} + dims(%arg0) { + ^bb0(%m: !smt.int): + // expected-error @+1 {{'iree_codegen.knob' op knob name 'nonexistent' not found in knobs dict}} + %bad = iree_codegen.knob "nonexistent" : !smt.int + } + return +} + +// ----- + +// Knob op: bare string in knobs dict does not satisfy knob lookup. +func.func @string_attr_not_a_knob(%arg0: index) { + iree_codegen.constraints target = , pipeline = None, + knobs = {name = "wg_m"} + dims(%arg0) { + ^bb0(%m: !smt.int): + // expected-error @+1 {{'iree_codegen.knob' op knob name 'wg_m' not found in knobs dict}} + %bad = iree_codegen.knob "wg_m" : !smt.int + } + return +} + +// ----- + +// Constraints op: pipeline attr must be DispatchLoweringPassPipelineAttr or +// PipelineAttrInterface — a plain string attr is neither. +func.func @constraints_invalid_pipeline(%arg0: index) { + // expected-error @+1 {{'iree_codegen.constraints' op attribute 'pipeline' failed to satisfy constraint}} + iree_codegen.constraints target = , pipeline = "not_a_pipeline", + knobs = {} + dims(%arg0) { + ^bb0(%m: !smt.int): + } + return +} diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/roundtrip.mlir b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/roundtrip.mlir index 40e349f56c1f..9025e80c1365 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/roundtrip.mlir +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/roundtrip.mlir @@ -111,3 +111,104 @@ func.func private @workgroup_scope_attr_linearize() attributes { } // CHECK-LABEL: func.func private @workgroup_scope_attr_linearize() // CHECK-SAME: scope = #iree_codegen.workgroup_scope + +// ----- + +// Test constraints op with knobs and dims. +func.func @constraints_op(%arg0: index, %arg1: index) { + iree_codegen.constraints target = , pipeline = LLVMGPUVectorDistribute, + knobs = {workgroup = [#iree_codegen.int_knob<"wg_m">, #iree_codegen.int_knob<"wg_n">]} + dims(%arg0, %arg1) { + ^bb0(%m: !smt.int, %n: !smt.int): + %wg_m = iree_codegen.knob "wg_m" : !smt.int + %wg_n = iree_codegen.knob "wg_n" : !smt.int + } + return +} +// CHECK-LABEL: func.func @constraints_op( +// CHECK-SAME: %[[M:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[N:[a-zA-Z0-9_]+]]: index +// CHECK: iree_codegen.constraints target = , pipeline = LLVMGPUVectorDistribute, +// CHECK: knobs = {workgroup = [#iree_codegen.int_knob<"wg_m">, #iree_codegen.int_knob<"wg_n">]} +// CHECK: dims(%[[M]], %[[N]]) +// CHECK: ^bb0(%{{.*}}: !smt.int, %{{.*}}: !smt.int): +// CHECK: iree_codegen.knob "wg_m" : !smt.int +// CHECK: iree_codegen.knob "wg_n" : !smt.int + +// ----- + +// Test constraints op with nested knobs (multiple dict groups) and SMT body. +func.func @constraints_op_with_smt_body(%arg0: index, %arg1: index) { + iree_codegen.constraints target = , pipeline = LLVMGPUVectorDistribute, + knobs = {reduction = [#iree_codegen.int_knob<"red_k">], workgroup = [#iree_codegen.int_knob<"wg_m">, #iree_codegen.int_knob<"wg_n">]} + dims(%arg0, %arg1) { + ^bb0(%m: !smt.int, %n: !smt.int): + %wg_m = iree_codegen.knob "wg_m" : !smt.int + %wg_n = iree_codegen.knob "wg_n" : !smt.int + %red_k = iree_codegen.knob "red_k" : !smt.int + %zero = smt.int.constant 0 + %wg_m_pos = smt.int.cmp gt %wg_m, %zero + smt.assert %wg_m_pos + } + return +} +// CHECK-LABEL: func.func @constraints_op_with_smt_body( +// CHECK-SAME: %[[M:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[N:[a-zA-Z0-9_]+]]: index +// CHECK: iree_codegen.constraints target = , pipeline = LLVMGPUVectorDistribute, +// CHECK: knobs = {reduction = [#iree_codegen.int_knob<"red_k">], workgroup = [#iree_codegen.int_knob<"wg_m">, #iree_codegen.int_knob<"wg_n">]} +// CHECK: dims(%[[M]], %[[N]]) +// CHECK: ^bb0(%{{.*}}: !smt.int, %{{.*}}: !smt.int): +// CHECK: iree_codegen.knob "wg_m" : !smt.int +// CHECK: iree_codegen.knob "wg_n" : !smt.int +// CHECK: iree_codegen.knob "red_k" : !smt.int +// CHECK: %[[ZERO:.*]] = smt.int.constant 0 +// CHECK: %[[CMP:.*]] = smt.int.cmp gt +// CHECK: smt.assert %[[CMP]] + +// ----- + +// Test constraints op with empty dims. +func.func @constraints_op_empty_dims() { + iree_codegen.constraints target = , pipeline = None, + knobs = {} + dims() { + ^bb0: + } + return +} +// CHECK-LABEL: func.func @constraints_op_empty_dims( +// CHECK: iree_codegen.constraints target = , pipeline = None, +// CHECK: knobs = {} +// CHECK: dims() + +// Test constraints op with extra attributes (placed before the body). +func.func @constraints_op_with_attrs(%arg0: index) { + iree_codegen.constraints target = , pipeline = LLVMGPUTileAndFuse, + knobs = {workgroup = [#iree_codegen.int_knob<"wg_m">]} + dims(%arg0) attributes {some_tag = "hello"} { + ^bb0(%m: !smt.int): + } + return +} +// CHECK-LABEL: func.func @constraints_op_with_attrs( +// CHECK-SAME: %[[M:[a-zA-Z0-9_]+]]: index +// CHECK: iree_codegen.constraints target = , pipeline = LLVMGPUTileAndFuse, +// CHECK: knobs = {workgroup = [#iree_codegen.int_knob<"wg_m">]} +// CHECK: dims(%[[M]]) attributes {some_tag = "hello"} +// CHECK: ^bb0(%{{.*}}: !smt.int): + +// Test constraints op with PipelineAttrInterface (pass_pipeline attr). +func.func @constraints_op_with_pass_pipeline(%arg0: index) { + iree_codegen.constraints target = , pipeline = #iree_codegen.pass_pipeline<"canonicalize">, + knobs = {} + dims(%arg0) { + ^bb0(%m: !smt.int): + } + return +} +// CHECK-LABEL: func.func @constraints_op_with_pass_pipeline( +// CHECK-SAME: %[[M:[a-zA-Z0-9_]+]]: index +// CHECK: iree_codegen.constraints target = , pipeline = #iree_codegen.pass_pipeline<"canonicalize">, +// CHECK: knobs = {} +// CHECK: dims(%[[M]]) diff --git a/compiler/src/iree/compiler/Tools/BUILD.bazel b/compiler/src/iree/compiler/Tools/BUILD.bazel index 015c35e9efd5..7cd7b00b6937 100644 --- a/compiler/src/iree/compiler/Tools/BUILD.bazel +++ b/compiler/src/iree/compiler/Tools/BUILD.bazel @@ -121,6 +121,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFToGPU", "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:SMTDialect", "@llvm-project//mlir:SPIRVDialect", "@llvm-project//mlir:SPIRVTransforms", "@llvm-project//mlir:ShapeDialect", diff --git a/compiler/src/iree/compiler/Tools/CMakeLists.txt b/compiler/src/iree/compiler/Tools/CMakeLists.txt index e88f2f6be2c3..a6d2d3d709f0 100644 --- a/compiler/src/iree/compiler/Tools/CMakeLists.txt +++ b/compiler/src/iree/compiler/Tools/CMakeLists.txt @@ -102,6 +102,7 @@ iree_cc_library( MLIRSCFDialect MLIRSCFToGPU MLIRSCFTransforms + MLIRSMT MLIRSPIRVDialect MLIRSPIRVTransforms MLIRShapeDialect diff --git a/compiler/src/iree/compiler/Tools/init_mlir_dialects.h b/compiler/src/iree/compiler/Tools/init_mlir_dialects.h index 4a6398bf7178..2c03bac03123 100644 --- a/compiler/src/iree/compiler/Tools/init_mlir_dialects.h +++ b/compiler/src/iree/compiler/Tools/init_mlir_dialects.h @@ -44,6 +44,7 @@ #include "mlir/Dialect/Quant/IR/Quant.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" +#include "mlir/Dialect/SMT/IR/SMTDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -85,6 +86,7 @@ inline void registerMlirDialects(DialectRegistry ®istry) { scf::SCFDialect, quant::QuantDialect, ROCDL::ROCDLDialect, + smt::SMTDialect, spirv::SPIRVDialect, arm_neon::ArmNeonDialect, arm_sve::ArmSVEDialect, From 3dccac06ead28c74a6958ec4b4fbb185f1ee7ef9 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Sat, 7 Mar 2026 08:59:54 -0500 Subject: [PATCH 2/2] Simplify --- .../Dialect/Codegen/IR/IREECodegenOps.cpp | 33 ++++++++----------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.cpp index 496cb93ec2f4..1e8589de94b4 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.cpp @@ -480,26 +480,19 @@ void WorkgroupCountHintOp::build(OpBuilder &builder, OperationState &state, /// Recursively check whether `name` appears as a knob name in `attr`. /// Checks IntKnobAttr names and recurses into DictionaryAttr/ArrayAttr. static bool hasKnobName(Attribute attr, StringRef name) { - if (auto intKnob = dyn_cast(attr)) { - return intKnob.getName().getValue() == name; - } - if (auto dictAttr = dyn_cast(attr)) { - for (NamedAttribute entry : dictAttr) { - if (hasKnobName(entry.getValue(), name)) { - return true; - } - } - return false; - } - if (auto arrayAttr = dyn_cast(attr)) { - for (Attribute element : arrayAttr) { - if (hasKnobName(element, name)) { - return true; - } - } - return false; - } - return false; + return TypeSwitch(attr) + .Case([&](IntKnobAttr knob) { return knob.getName().getValue() == name; }) + .Case([&](DictionaryAttr dict) { + return llvm::any_of(dict, [&](NamedAttribute entry) { + return hasKnobName(entry.getValue(), name); + }); + }) + .Case([&](ArrayAttr array) { + return llvm::any_of(array, [&](Attribute element) { + return hasKnobName(element, name); + }); + }) + .Default(false); } LogicalResult ConstraintsOp::verify() {