44 commits
d7a8ca0  Add ONNXToLinalg conversion pass with MatMul support (Dec 4, 2025)
e7ff618  Add ONNX to Linalg lowering pipeline (Dec 8, 2025)
8aed3d4  Add bufferization to ONNX to Linalg pipeline (Dec 8, 2025)
dc1c1a3  Add --use-linalg-path flag to choose between Krnl and Linalg pipelines (Dec 8, 2025)
1f01d33  Split addONNXToLinalgPasses into addLinalgToAffinePasses and addLinal… (Dec 8, 2025)
5ffba8a  Refactor addPasses: translate comments to English and clarify pass ex… (Dec 8, 2025)
0a89257  Fix pass execution order: ensure addONNXToMLIRPasses is called for Li… (Dec 8, 2025)
75eb929  Refactor Linalg lowering pipeline: simplify to use pm.addPass only (Dec 8, 2025)
73e7850  Analyze PassManager execution mechanism and fix pass ordering (Dec 8, 2025)
f1f8f6f  Try different pass ordering: convert-linalg-to-loops before bufferiza… (Dec 8, 2025)
700b832  Register BufferizableOpInterface for arith, linalg, and tensor dialects (Dec 8, 2025)
3ec89c6  Register BufferizableOpInterface for func dialect (Dec 8, 2025)
5446d2c  Add one-shot-bufferize to Linalg pipeline and fix pass ordering (Dec 8, 2025)
ebf6a31  Refactor Linalg lowering pipeline to match Krnl structure (Dec 8, 2025)
e87ddfd  Add comparison test between Krnl and Linalg pipelines (Dec 8, 2025)
f18f6f5  Remove onnx.EntryPoint before LLVM conversion in Linalg path (Dec 8, 2025)
9c87e99  Apply coding practices: Use C++ style casting (Dec 8, 2025)
75f7ccc  Remove debug logs and apply clang-format (Dec 8, 2025)
c2c2108  Apply clang-format to ConvertONNXToLinalg.cpp (Dec 8, 2025)
9d7dd17  Refactor Linalg to LLVM pipeline to use createConvertKrnlToLLVMPass (Dec 9, 2025)
e75cb3a  Add analysis of signature generation issue in Linalg pipeline (Dec 10, 2025)
14d7027  Fix ONNXEntryPointLowering pattern in Linalg pipeline (Dec 10, 2025)
497560c  Translate comments in driver.cpp to English (Dec 10, 2025)
99c4896  Add test driver for Linalg pipeline MatMul execution (Dec 10, 2025)
61a1159  Refactor ONNXEntryPointLowering to use existing implementation (Dec 10, 2025)
82dae77  Apply clang-format to modified files (Dec 10, 2025)
393b29d  Clean up conversion directory (Dec 10, 2025)
b02d925  Remove LINALG_PIPELINE_PASSES.md documentation file (Dec 10, 2025)
1aea799  Clean up directory structure (Dec 10, 2025)
6cfa379  Clang Format (Dec 12, 2025)
5718b12  Support empty inputs to the compiled model and remove input arguments… (tungld, Dec 4, 2025)
bba6ea9  Make dynamic dimension analysis for Reshape more general (#3341) (tungld, Dec 5, 2025)
c9f4422  added comment to script and doc (#3312) (AlexandreEichenberger, Dec 5, 2025)
c5c3421  Fix a bug in saving a json file and a bug in onnx to zhigh for quanti… (tungld, Dec 5, 2025)
d8336ee  Make verifyInputTensor true by default (#3079) (chentong319, Dec 5, 2025)
9912844  fix osx build for deprecation of git image (#3348) (Sunny-Anand, Dec 10, 2025)
9a30927  Add ONNXToLinalg conversion pass with MatMul support (#3343) (kimm240, Dec 10, 2025)
52a0a96  Clang Format (Dec 12, 2025)
fcb1162  delete test file that doesn't have RUN script (Dec 12, 2025)
206ea6b  Chore: trigger ci (Dec 15, 2025)
75fb7ad  Chore: trigger ci (Dec 15, 2025)
30f8bda  Chore: trigger ci (Dec 15, 2025)
779c1b9  Chore: trigger ci (Dec 15, 2025)
8859336  Merge branch 'main' into feature/linalg-to-llvm-pipeline (kimm240, Dec 15, 2025)
1 change: 1 addition & 0 deletions src/Compiler/CMakeLists.txt
@@ -89,6 +89,7 @@ add_onnx_mlir_library(OMCompilerPasses
${OMLibs}
OMCompilerOptions
MLIRAffineTransforms
MLIRArithToLLVM
MLIRBufferizationPipelines
MLIRBufferizationToMemRef
MLIRIR
14 changes: 14 additions & 0 deletions src/Compiler/CompilerDialects.cpp
@@ -14,13 +14,18 @@
#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"

#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;
@@ -43,6 +48,7 @@ DialectRegistry registerDialects(ArrayRef<accel::Accelerator::Kind> accels) {
registry.insert<shape::ShapeDialect>();
registry.insert<math::MathDialect>();
registry.insert<memref::MemRefDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<ONNXDialect>();
registry.insert<KrnlDialect>();
registry.insert<cf::ControlFlowDialect>();
@@ -60,6 +66,14 @@ DialectRegistry registerDialects(ArrayRef<accel::Accelerator::Kind> accels) {
memref::registerAllocationOpInterfaceExternalModels(registry);
arith::registerBufferDeallocationOpInterfaceExternalModels(registry);

// Register BufferizableOpInterface for one-shot bufferization (needed for
// Linalg path)
arith::registerBufferizableOpInterfaceExternalModels(registry);
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
registry);
linalg::registerBufferizableOpInterfaceExternalModels(registry);
tensor::registerBufferizableOpInterfaceExternalModels(registry);

return registry;
}

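The registrations above are what allow the Linalg path's one-shot bufferization to query BufferizableOpInterface on arith, func, linalg, and tensor ops. A minimal standalone sketch of the same setup (a hypothetical helper, not code from this PR; it only reuses the registration calls shown above):

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/IR/DialectRegistry.h"

// Hypothetical helper: collects the dialects and external models that
// one-shot bufferization needs on the Linalg path.
static void registerLinalgBufferizationSupport(mlir::DialectRegistry &registry) {
  registry.insert<mlir::arith::ArithDialect, mlir::func::FuncDialect,
      mlir::linalg::LinalgDialect, mlir::tensor::TensorDialect>();
  // External models supply BufferizableOpInterface implementations for ops
  // whose dialects do not implement the interface natively.
  mlir::arith::registerBufferizableOpInterfaceExternalModels(registry);
  mlir::bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
      registry);
  mlir::linalg::registerBufferizableOpInterfaceExternalModels(registry);
  mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
}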
15 changes: 11 additions & 4 deletions src/Compiler/CompilerOptions.cpp
@@ -98,10 +98,12 @@ ProfileIRs profileIR; // onnx-mlir only
OptReport optReport; // onnx-mlir only
bool enableTiming; // onnx-mlir only
bool enableBoundCheck; // onnx-mlir only
bool split_input_file; // onnx-mlir-opt only
bool verify_diagnostics; // onnx-mlir-opt only
bool verify_passes; // onnx-mlir-opt only
bool allowUnregisteredDialects; // onnx-mlir-opt only
bool useLinalgPath; // onnx-mlir only

bool split_input_file; // onnx-mlir-opt only
bool verify_diagnostics; // onnx-mlir-opt only
bool verify_passes; // onnx-mlir-opt only
bool allowUnregisteredDialects; // onnx-mlir-opt only

// Category for common options shared between onnx-mlir and onnx-mlir-opt.
llvm::cl::OptionCategory OnnxMlirCommonOptions("common options",
@@ -633,6 +635,11 @@ static llvm::cl::opt<bool, true> verifyInputTensorsOpt("verifyInputTensors",
llvm::cl::location(verifyInputTensors), llvm::cl::init(true),
llvm::cl::cat(OnnxMlirOptions));

static llvm::cl::opt<bool, true> useLinalgPathOpt("use-linalg-path",
llvm::cl::desc("Use Linalg lowering path instead of Krnl (default=false)."),
llvm::cl::location(useLinalgPath), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));

static llvm::cl::opt<bool, true> allowSortingOpt("allowSorting",
llvm::cl::desc("Perform topological sort on onnx graph."),
llvm::cl::location(allowSorting), llvm::cl::init(true),
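Usage note (illustrative, not part of the diff): because the option uses llvm::cl::location(useLinalgPath), the parsed value is written straight into the global flag declared in CompilerOptions.hpp, so pipeline code can branch on plain useLinalgPath. On the command line this would look like onnx-mlir --EmitLLVMIR --use-linalg-path model.onnx; the flag defaults to false, so the existing Krnl path remains the default.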
1 change: 1 addition & 0 deletions src/Compiler/CompilerOptions.hpp
@@ -144,6 +144,7 @@ extern OptReport optReport; // onnx-mlir only
extern bool enableTiming; // onnx-mlir only
extern bool enableBoundCheck; // onnx-mlir only
extern bool debugTestCompilerOpt; // onnx-mlir only
extern bool useLinalgPath; // onnx-mlir only

extern bool split_input_file; // onnx-mlir-opt only
extern bool verify_diagnostics; // onnx-mlir-opt only
169 changes: 161 additions & 8 deletions src/Compiler/CompilerPasses.cpp
@@ -18,25 +18,33 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Bufferization/Pipelines/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"

#include "src/Builder/ModelInputShaper.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Compiler/CompilerPasses.hpp"
#include "src/Compiler/DisposableGarbageCollector.hpp"
#include "src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Dialect/Mlir/VectorMachineSupport.hpp"
#include "src/Dialect/ONNX/ONNXDialect.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Pass/Passes.hpp"

using namespace mlir;
@@ -224,6 +232,117 @@ void addKrnlToAffinePasses(mlir::PassManager &pm) {
onnx_mlir::krnl::createConvertKrnlToAffinePass(enableParallel));
}

void addONNXToLinalgPasses(mlir::PassManager &pm) {
// Convert ONNX operations to Linalg dialect
// Similar to addONNXToKrnlPasses for Krnl path
// Note: This assumes addONNXToMLIRPasses has been called first to:
// - Replace ONNXReturnOp with func::ReturnOp (createStandardFuncReturnPass)
// - Clean dead code (createSymbolDCEPass)
// - Other preprocessing passes

pm.addNestedPass<func::FuncOp>(onnx_mlir::createConvertONNXToLinalg());

// Convert ONNXEntryPointOp to KrnlEntryPointOp
// This MUST be done BEFORE bufferization because getSignature() needs
// tensor types, not memref types. After bufferization, function signatures
// are converted to memref types which cannot be properly serialized.
// This uses the same ONNXEntryPointLowering pattern as the Krnl pipeline
// to ensure consistent signature generation.
struct ConvertONNXEntryPointToKrnlPass
: public mlir::PassWrapper<ConvertONNXEntryPointToKrnlPass,
mlir::OperationPass<mlir::ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
ConvertONNXEntryPointToKrnlPass)
void runOnOperation() override {
mlir::ModuleOp module = getOperation();
mlir::MLIRContext *context = &getContext();
mlir::RewritePatternSet patterns(context);

// Use the existing ONNXEntryPointLowering pattern from
// ConvertONNXToKrnl.cpp. This ensures signature generation identical to
// that of the Krnl pipeline.
populateLoweringONNXEntryPointOpPattern(patterns, context);

// Apply patterns greedily
mlir::GreedyRewriteConfig config;
if (failed(mlir::applyPatternsGreedily(
module, std::move(patterns), config))) {
signalPassFailure();
}
}
StringRef getArgument() const override {
return "convert-onnx-entry-point-to-krnl";
}
StringRef getDescription() const override {
return "Convert onnx.EntryPoint to krnl.EntryPoint for entry point "
"function generation (same as Krnl pipeline)";
}
};
pm.addPass(std::make_unique<ConvertONNXEntryPointToKrnlPass>());

// One-shot bufferization (Tensor → Memref)
// This must be a module-level pass to handle function boundaries
bufferization::OneShotBufferizePassOptions bufferizeOptions;
bufferizeOptions.bufferizeFunctionBoundaries = true;
pm.addPass(bufferization::createOneShotBufferizePass(bufferizeOptions));

// An additional pass of canonicalization is helpful after conversion
pm.addPass(mlir::createCanonicalizerPass());
}

void addLinalgToAffinePasses(mlir::PassManager &pm) {
// Convert Linalg operations to Affine/SCF loops
// Similar to addKrnlToAffinePasses for Krnl path
auto &funcPM = pm.nest<func::FuncOp>();

// 1. Linalg → Loops (creates structured control-flow loops: affine.for,
// scf.for)
funcPM.addPass(mlir::createConvertLinalgToLoopsPass());

// 2. Buffer management (MUST run before convert-scf-to-cf).
// buildBufferDeallocationPipeline requires the structured control-flow loops
// created by convert-linalg-to-loops above; convert-scf-to-cf removes those
// loops, so buffer management must come first.
funcPM.addPass(bufferization::createBufferLoopHoistingPass());
bufferization::BufferDeallocationPipelineOptions bufferDeallocOptions;
mlir::bufferization::buildBufferDeallocationPipeline(
funcPM, bufferDeallocOptions);
funcPM.addPass(mlir::bufferization::createOptimizeAllocationLivenessPass());
funcPM.addPass(mlir::createConvertBufferizationToMemRefPass());

// 3. Lower to Affine/SCF (after buffer management)
funcPM.addPass(mlir::createLowerAffinePass());
funcPM.addPass(mlir::createSCFToControlFlowPass());
}

void addLinalgToLLVMPasses(mlir::PassManager &pm, std::string outputNameNoExt) {
// Convert remaining operations to LLVM dialect
// Similar to addKrnlToLLVMPasses for Krnl path
// Note: onnx.EntryPoint is already converted to krnl.EntryPoint in
// addONNXToLinalgPasses, so we can use createConvertKrnlToLLVMPass()
// to generate runtime functions (omQueryEntryPoints, omInputSignature,
// omOutputSignature, etc.)

// This pass handles:
// 1. Entry point preprocessing (PostfixEntrypointNames,
// removeUnhandledParamAttrs)
// 2. Runtime information collection (recordInputOutputMemRefTypes,
// hasSingleEntryPoint, determineOwnershipForOutputOMTensors)
// 3. KrnlEntryPointOp → LLVM conversion (dynamic entry point functions,
// OMTensor conversion, accelerator initialization, signature recording)
// 4. Runtime function generation (omQueryEntryPoints, omInputSignature,
// omOutputSignature)
// 5. Other features (constants file storage, C wrapper, .lrodata section)
pm.addPass(krnl::createConvertKrnlToLLVMPass(verifyInputTensors,
/*useLRODATA=*/(modelSize == ModelSize::large),
/*storeConstantsToFile=*/storeConstantsToFile,
constantsToFileSingleThreshold, constantsToFileTotalThreshold,
outputNameNoExt, enableParallel));
pm.addPass(mlir::createReconcileUnrealizedCastsPass());
pm.addPass(mlir::createCanonicalizerPass());
}

void addKrnlToLLVMPasses(
mlir::OpPassManager &pm, std::string outputNameNoExt, bool enableCSE) {
if (enableCSE)
@@ -330,19 +449,53 @@ void addPasses(mlir::OwningOpRef<ModuleOp> &module, mlir::PassManager &pm,
EmissionTargetType emissionTarget, std::string outputNameNoExt) {
InputIRLevelType inputIRLevel = determineInputIRLevel(module);

if (inputIRLevel <= ONNXLevel && emissionTarget >= EmitONNXIR)
addONNXToMLIRPasses(pm, /*target CPU*/ maccel.empty());
// Step 1: Convert ONNX to intermediate representation (Krnl or Linalg)
if (inputIRLevel <= ONNXLevel) {
// Always call addONNXToMLIRPasses first for preprocessing (ONNXReturnOp ->
// func::ReturnOp, etc.). This is needed for both the Krnl and Linalg paths.
bool shouldCallONNXToMLIR = emissionTarget >= EmitONNXIR ||
(emissionTarget >= EmitMLIR && useLinalgPath);
if (shouldCallONNXToMLIR) {
addONNXToMLIRPasses(pm, /*target CPU*/ maccel.empty());
}

if (useLinalgPath) {
// Linalg path: Convert ONNX to Linalg (after preprocessing)
addONNXToLinalgPasses(pm);
}
}

// Step 2: Lower to Affine dialect (for EmitMLIR)
if (emissionTarget >= EmitMLIR) {
if (inputIRLevel <= ONNXLevel)
addONNXToKrnlPasses(
pm, OptimizationLevel, /*enableCSE*/ true, ONNXOpStats);
if (inputIRLevel <= MLIRLevel)
if (inputIRLevel <= ONNXLevel) {
if (useLinalgPath) {
// Linalg path: Lower Linalg to Affine
addLinalgToAffinePasses(pm);
} else {
// Krnl path: Convert ONNX to Krnl, then Krnl to Affine
addONNXToKrnlPasses(
pm, OptimizationLevel, /*enableCSE*/ true, ONNXOpStats);
}
}
// For Krnl path: Lower Krnl to Affine (when input is already at MLIR level)
if (inputIRLevel <= MLIRLevel && !useLinalgPath) {
addKrnlToAffinePasses(pm);
}
}

if (inputIRLevel <= LLVMLevel && emissionTarget >= EmitLLVMIR)
addKrnlToLLVMPasses(pm, outputNameNoExt, /*enableCSE=*/true);
// Step 3: Lower to LLVM dialect (for EmitLLVMIR)
if (emissionTarget >= EmitLLVMIR) {
if (useLinalgPath) {
// Linalg path: Lower remaining operations to LLVM
// Uses createConvertKrnlToLLVMPass() to generate runtime functions
// since onnx.EntryPoint is already converted to krnl.EntryPoint
addLinalgToLLVMPasses(pm, outputNameNoExt);
} else {
// Krnl path: Lower Krnl to LLVM
if (inputIRLevel <= LLVMLevel)
addKrnlToLLVMPasses(pm, outputNameNoExt, /*enableCSE=*/true);
}
}
}

} // namespace onnx_mlir
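For orientation, the Linalg path scheduled by the useLinalgPath branches above runs in three stages. A hypothetical driver-side sketch (not code from this PR; it assumes the addONNXToMLIRPasses preprocessing has already been added to the same pass manager) of how the new helpers compose:

#include <string>

#include "mlir/Pass/PassManager.h"
#include "src/Compiler/CompilerPasses.hpp"

// Hypothetical helper mirroring the Linalg-path branches of addPasses().
static void buildLinalgLoweringPipeline(
    mlir::PassManager &pm, std::string outputNameNoExt) {
  // ONNX -> Linalg, onnx.EntryPoint -> krnl.EntryPoint, one-shot bufferize.
  onnx_mlir::addONNXToLinalgPasses(pm);
  // Linalg -> loops, buffer deallocation pipeline, Affine/SCF -> CF.
  onnx_mlir::addLinalgToAffinePasses(pm);
  // Remaining ops (including krnl.EntryPoint) -> LLVM dialect plus the
  // runtime query functions generated by createConvertKrnlToLLVMPass.
  onnx_mlir::addLinalgToLLVMPasses(pm, outputNameNoExt);
}

The flag only selects which of the two stacks addPasses wires up; the Krnl path itself is unchanged.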
3 changes: 3 additions & 0 deletions src/Compiler/CompilerPasses.hpp
@@ -25,6 +25,9 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
std::string ONNXOpsStatFilename);
void addKrnlToAffinePasses(mlir::PassManager &pm);
void addONNXToLinalgPasses(mlir::PassManager &pm);
void addLinalgToAffinePasses(mlir::PassManager &pm);
void addLinalgToLLVMPasses(mlir::PassManager &pm, std::string outputNameNoExt);
void addKrnlToLLVMPasses(
mlir::OpPassManager &pm, std::string outputNameNoExt, bool enableCSE);
InputIRLevelType determineInputIRLevel(
7 changes: 6 additions & 1 deletion src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
@@ -189,6 +189,11 @@ std::map<std::string, std::string> ONNXEntryPointLowering::typeMap = {
{std::string(" ui16 "), std::string(" \"ui16\" ")},
{std::string(" ui8 "), std::string(" \"ui8\" ")}};

void populateLoweringONNXEntryPointOpPattern(
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<ONNXEntryPointLowering>(ctx);
}

void populateONNXToKrnlConversionPattern(RewritePatternSet &patterns,
TypeConverter &typeConverter, MLIRContext *ctx, DimAnalysis *dimAnalysis,
bool enableTiling, bool enableSIMD, bool enableParallel,
@@ -287,7 +292,7 @@ void populateONNXToKrnlConversionPattern(RewritePatternSet &patterns,
populateLoweringONNXSequenceInsertOpPattern(patterns, typeConverter, ctx);
populateLoweringONNXSequenceLengthOpPattern(patterns, typeConverter, ctx);
// Entry point
patterns.insert<ONNXEntryPointLowering>(ctx);
populateLoweringONNXEntryPointOpPattern(patterns, ctx);
// Additional
populateLoweringONNXCustomOpPattern(patterns, typeConverter, ctx);
populateLoweringONNXLayoutTransformOpPattern(patterns, typeConverter, ctx, enableParallel);
4 changes: 4 additions & 0 deletions src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
@@ -440,6 +440,10 @@ void populateLoweringONNXSequenceInsertOpPattern(
void populateLoweringONNXSequenceLengthOpPattern(
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);

// Entry point lowering
void populateLoweringONNXEntryPointOpPattern(
mlir::RewritePatternSet &, mlir::MLIRContext *);

// `Tensor` directory methods:
void populateLoweringONNXArgMinMaxOpPattern(
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);