Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ def __init__(self, repo_map: Dict[str, str]):
"@llvm-project//mlir:MlirLspServerLib": ["MLIRLspServerLib"],
"@llvm-project//mlir:MlirTableGenMain": ["MLIRTableGen"],
"@llvm-project//mlir:MlirOptLib": ["MLIROptLib"],
"@llvm-project//mlir:CAPISMT": [
"MLIRCAPISMT",
"MLIRCAPIExportSMTLIB",
],
"@llvm-project//mlir:SMTDialect": ["MLIRSMT"],
"@llvm-project//mlir:TargetSMTLIB": ["MLIRExportSMTLIB"],
"@llvm-project//mlir:VectorOps": ["MLIRVector"],
# StableHLO.
"@stablehlo//:chlo_ops": [
Expand Down
1 change: 1 addition & 0 deletions compiler/bindings/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ set(_SOURCE_COMPONENTS
MLIRPythonSources.Dialects.rocdl
MLIRPythonSources.Dialects.scf
MLIRPythonSources.Dialects.shape
MLIRPythonSources.Dialects.smt
MLIRPythonSources.Dialects.structured_transform
MLIRPythonSources.Dialects.tensor
MLIRPythonSources.Dialects.tosa
Expand Down
1 change: 1 addition & 0 deletions compiler/bindings/python/test/ir/dialects_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def decorator_builder(func):
rocdl,
scf,
shape,
smt,
tensor,
tosa,
transform,
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/API/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:CAPILLVM",
"@llvm-project//mlir:CAPILinalg",
"@llvm-project//mlir:CAPIPDL",
"@llvm-project//mlir:CAPISMT",
"@llvm-project//mlir:CAPITarget",
"@llvm-project//mlir:CAPITransformDialect",
"@llvm-project//mlir:CAPITransformDialectTransforms",
Expand Down
8 changes: 7 additions & 1 deletion compiler/src/iree/compiler/API/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ iree_cc_library(
IREEDialectsCAPI
MLIRCAPIAMDGPU
MLIRCAPIDebug
MLIRCAPIExportSMTLIB
MLIRCAPIGPU
MLIRCAPIIR
MLIRCAPIInterfaces
MLIRCAPILLVM
MLIRCAPILinalg
MLIRCAPIPDL
MLIRCAPISMT
MLIRCAPITarget
MLIRCAPITransformDialect
MLIRCAPITransformDialectTransforms
Expand Down Expand Up @@ -80,6 +82,8 @@ set(_EXPORT_OBJECT_LIBS
obj.MLIRCAPILLVM
obj.MLIRCAPILinalg
obj.MLIRCAPIPDL
obj.MLIRCAPISMT
obj.MLIRCAPIExportSMTLIB
obj.MLIRCAPITarget
obj.MLIRCAPITransforms
obj.MLIRCAPITransformDialect
Expand Down Expand Up @@ -194,9 +198,11 @@ target_link_libraries(iree_compiler_API_SharedImpl PRIVATE
${_EXPORT_OBJECT_DEPS}
)

# Link MLIRTargetLLVMIRImport directly since it is not exported as an object library.
# Link MLIRTargetLLVMIRImport and MLIRExportSMTLIB directly since they are
# not exported as object libraries.
target_link_libraries(iree_compiler_API_SharedImpl PRIVATE
MLIRTargetLLVMIRImport
MLIRExportSMTLIB
)

# If not using sanitizers, ask linkers to error on undefined symbols.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ iree_td_library(
"@llvm-project//mlir:LinalgOpsTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SCFTdFiles",
"@llvm-project//mlir:SMTTdFiles",
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
"@llvm-project//mlir:TilingInterfaceTdFiles",
"@llvm-project//mlir:VectorInterfacesTdFiles",
Expand Down Expand Up @@ -118,6 +119,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:SMTDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TilingInterface",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ iree_cc_library(
MLIRPass
MLIRSCFDialect
MLIRSCFTransforms
MLIRSMT
MLIRSupport
MLIRTensorDialect
MLIRTilingInterface
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace mlir::iree_compiler::IREE::Codegen {
/// Parses either a DispatchLoweringPassPipeline enum keyword (e.g.,
/// `CPUDefault`) or a generic attribute implementing PipelineAttrInterface
/// (e.g., `#iree_codegen.pass_pipeline<"canonicalize">`).
static ParseResult parsePipelineAttr(AsmParser &parser, Attribute &result) {
ParseResult parsePipelineAttr(AsmParser &parser, Attribute &result) {
StringRef keyword;
SMLoc loc = parser.getCurrentLocation();
if (succeeded(parser.parseOptionalKeyword(&keyword))) {
Expand All @@ -53,7 +53,7 @@ static ParseResult parsePipelineAttr(AsmParser &parser, Attribute &result) {

/// Prints DispatchLoweringPassPipelineAttr as a bare keyword and other
/// attributes (e.g., PipelineAttrInterface impls) via the generic printer.
static void printPipelineAttr(AsmPrinter &printer, Attribute pipelineAttr) {
void printPipelineAttr(AsmPrinter &printer, Attribute pipelineAttr) {
if (auto enumAttr =
dyn_cast<DispatchLoweringPassPipelineAttr>(pipelineAttr)) {
printer << stringifyEnum(enumAttr.getValue());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,25 @@ bool shouldSetTunerAttributes();
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h.inc"
// clang-format on

namespace mlir::iree_compiler::IREE::Codegen {

/// Parses either a DispatchLoweringPassPipeline enum keyword or a generic
/// attribute implementing PipelineAttrInterface.
ParseResult parsePipelineAttr(AsmParser &parser, Attribute &result);
inline ParseResult parsePipelineAttr(OpAsmParser &parser, Attribute &result) {
return parsePipelineAttr(static_cast<AsmParser &>(parser), result);
}

/// Prints DispatchLoweringPassPipelineAttr as a bare keyword and other
/// attributes via the generic printer.
void printPipelineAttr(AsmPrinter &printer, Attribute pipelineAttr);
inline void printPipelineAttr(OpAsmPrinter &printer, Operation *,
Attribute pipelineAttr) {
printPipelineAttr(printer, pipelineAttr);
}

} // namespace mlir::iree_compiler::IREE::Codegen

namespace mlir::iree_compiler {
//===----------------------------------------------------------------------===//
// Constant names.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -732,4 +732,22 @@ def IREECodegen_RootOpAttr : AttrDef<IREECodegen_Dialect, "RootOp"> {
let assemblyFormat = "`<` `set` `=` $set `>`";
}

//===---------------------------------------------------------------------===//
// iree_codegen.int_knob
//===---------------------------------------------------------------------===//

def IREECodegen_IntKnobAttr : AttrDef<IREECodegen_Dialect, "IntKnob"> {
let mnemonic = "int_knob";
let summary = "Integer-valued tunable knob placeholder.";
let description = [{
Represents a named placeholder for an integer tunable parameter in a
constraints knobs dictionary. During constraint generation, these appear
in tiling arrays (workgroup, reduction, thread), workgroup_size, and
subgroup_size positions. The name matches the corresponding
`iree_codegen.knob` op name.
}];
let parameters = (ins "StringAttr":$name);
let assemblyFormat = "`<` $name `>`";
}

#endif // IREE_COMPILER_CODEGEN_DIALECT_IREECODEGENATTRS
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h"
#include "iree/compiler/Codegen/Dialect/PCF/IR/PCFInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/SMT/IR/SMTDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/IR/DialectImplementation.h"
// clang-format off
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,10 @@ def IREECodegen_Dialect : Dialect {
// We use linalg.iterator_type to avoid clashing definitions for the
// enum attribute that wraps mlir::utils::IteratorType, so we need to be
// able to parse it.
let dependentDialects = ["::mlir::linalg::LinalgDialect"];
let dependentDialects = [
"::mlir::linalg::LinalgDialect",
"::mlir::smt::SMTDialect",
];
let useDefaultTypePrinterParser = 1;
let useDefaultAttributePrinterParser = 1;
let hasOperationAttrVerify = 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@

#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"

#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SMT/IR/SMTTypes.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineMap.h"
Expand All @@ -22,6 +25,23 @@
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Support/LLVM.h"

// Custom parse/print helper for the knobs dictionary in constraints op.
// Prints `knobs = { ... }` on its own line with newlines before and after.
static mlir::ParseResult parseKnobsDictionary(mlir::OpAsmParser &parser,
mlir::DictionaryAttr &attr) {
if (parser.parseKeyword("knobs") || parser.parseEqual()) {
return mlir::failure();
}
return parser.parseAttribute(attr);
}
static void printKnobsDictionary(mlir::OpAsmPrinter &p, mlir::Operation *,
mlir::DictionaryAttr attr) {
p.printNewline();
p << " knobs = ";
p.printAttributeWithoutType(attr);
p.printNewline();
}

// clang-format off
#define GET_OP_CLASSES
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.cpp.inc" // IWYU pragma: keep
Expand Down Expand Up @@ -452,3 +472,73 @@ void WorkgroupCountHintOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, dynamicSizes,
builder.getDenseI64ArrayAttr(staticSizes));
}

//===----------------------------------------------------------------------===//
// ConstraintsOp
//===----------------------------------------------------------------------===//

/// Recursively check whether `name` appears as a knob name in `attr`.
/// Checks IntKnobAttr names and recurses into DictionaryAttr/ArrayAttr.
static bool hasKnobName(Attribute attr, StringRef name) {
return TypeSwitch<Attribute, bool>(attr)
.Case([&](IntKnobAttr knob) { return knob.getName().getValue() == name; })
.Case([&](DictionaryAttr dict) {
return llvm::any_of(dict, [&](NamedAttribute entry) {
return hasKnobName(entry.getValue(), name);
});
})
.Case([&](ArrayAttr array) {
return llvm::any_of(array, [&](Attribute element) {
return hasKnobName(element, name);
});
})
.Default(false);
}

LogicalResult ConstraintsOp::verify() {
Block &block = getBody().front();

// Check block arg count matches problem_dims count.
if (block.getNumArguments() != getProblemDims().size()) {
return emitOpError("expected ")
<< getProblemDims().size() << " block arguments but got "
<< block.getNumArguments();
}

// Check all block args are !smt.int.
smt::IntType smtIntType = smt::IntType::get(getContext());
for (auto [i, arg] : llvm::enumerate(block.getArguments())) {
if (arg.getType() != smtIntType) {
return emitOpError("block argument #")
<< i << " must be !smt.int but got " << arg.getType();
}
}

// Verify knob ops: check names exist in the dict and reject duplicates.
// Note that we considered using SymbolTable for uniqueness, but the knobs
// dictionary contains attributes (not ops), so we'd still need custom
// verification for dictionary <--> KnobOp correspondence.
// Rejecting duplicates is not just pedantic -- when this op is lowered to
// SMT, each KnobOp becomes an `smt.declare_const`. The SMT dialect creates
// a fresh symbolic constant per declaration regardless of the name string,
// so two KnobOps with the same name would silently introduce two independent
// solver variables where one was intended, producing incorrect constraints.
DictionaryAttr knobs = getKnobsAttr();
llvm::StringMap<Location> seenKnobs;
for (auto knobOp : block.getOps<KnobOp>()) {
Comment thread
bangtianliu marked this conversation as resolved.
auto [it, inserted] =
seenKnobs.try_emplace(knobOp.getName(), knobOp.getLoc());
if (!inserted) {
InFlightDiagnostic diag = knobOp.emitOpError("duplicate knob name '")
<< knobOp.getName() << "'";
diag.attachNote(it->second) << "first occurrence here";
return diag;
}
if (!hasKnobName(knobs, knobOp.getName())) {
return knobOp.emitOpError("knob name '")
<< knobOp.getName() << "' not found in knobs dict";
}
}

return success();
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h"
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SMT/IR/SMTTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
Expand Down
Loading
Loading