Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/Conversion/ONNXToLinalg/Math/MatMul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,18 @@ namespace onnx_mlir {
struct ONNXMatMulOpLoweringToLinalg : public OpRewritePattern<ONNXMatMulOp> {
ONNXMatMulOpLoweringToLinalg(
MLIRContext *ctx, const std::string &linalgOps, bool useLinalgPath)
: OpRewritePattern<ONNXMatMulOp>(ctx), linalgOps(linalgOps),
useLinalgPath(useLinalgPath) {}
: OpRewritePattern<ONNXMatMulOp>(ctx), useLinalgPath(useLinalgPath),
linalgOpsMatcher(
linalgOps.empty()
? nullptr
: std::make_unique<EnableByRegexOption>(false, linalgOps)) {}

LogicalResult matchAndRewrite(
ONNXMatMulOp matMulOp, PatternRewriter &rewriter) const final {
// Check if this operation should be converted to Linalg based on
// --linalg-ops option
Operation *op = matMulOp.getOperation();
if (!shouldConvertToLinalg(op, linalgOps, useLinalgPath)) {
if (!shouldConvertToLinalg(op, linalgOpsMatcher.get(), useLinalgPath)) {
return rewriter.notifyMatchFailure(
matMulOp, "operation not selected for Linalg conversion");
}
Expand Down Expand Up @@ -97,8 +100,8 @@ struct ONNXMatMulOpLoweringToLinalg : public OpRewritePattern<ONNXMatMulOp> {
}

private:
std::string linalgOps;
bool useLinalgPath;
mutable std::unique_ptr<EnableByRegexOption> linalgOpsMatcher;
};

void populateLoweringONNXMatMulOpToLinalgPattern(RewritePatternSet &patterns,
Expand Down
20 changes: 9 additions & 11 deletions src/Conversion/ONNXToLinalg/ONNXToLinalgCommon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,15 @@ namespace onnx_mlir {
// --use-linalg-path is enabled.
// Note: When convert-onnx-to-linalg pass is explicitly run (e.g., via
// onnx-mlir-opt), we default to converting operations if no options are set.
inline bool shouldConvertToLinalg(
mlir::Operation *op, const std::string &linalgOps, bool useLinalgPath) {
// When convert-onnx-to-linalg pass is explicitly run (e.g., via
// onnx-mlir-opt), we default to converting all operations unless linalgOps
// is explicitly set
if (linalgOps.empty()) {
// If linalgOps is not specified, check useLinalgPath flag
// If useLinalgPath is true, convert all operations to Linalg
// Otherwise, default to true for onnx-mlir-opt usage (when pass is
// explicitly run) Note: In onnx-mlir-opt, useLinalgPath may not be
// initialized, so we default to true
// The linalgOpsMatcher parameter should be a pointer to an EnableByRegexOption
// instance that is thread-safe (each pattern instance should have its own).
// Note: linalgOpsMatcher is non-const because isEnabled() modifies internal
// cache.
inline bool shouldConvertToLinalg(mlir::Operation *op,
EnableByRegexOption *linalgOpsMatcher, bool useLinalgPath) {
// When linalgOpsMatcher is null or empty, default to converting all
// operations
if (!linalgOpsMatcher) {
return true;
}

Expand Down
Loading