Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/Compiler/CompilerOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ bool allowUnregisteredDialects; // onnx-mlir-opt only

bool useLinalgPath; // onnx-mlir only
std::string linalgOps; // common for both onnx-mlir and onnx-mlir-opt
std::vector<std::string> discardableAttrs; // onnx-mlir only

// Category for common options shared between onnx-mlir and onnx-mlir-opt.
llvm::cl::OptionCategory OnnxMlirCommonOptions("common options",
Expand Down Expand Up @@ -749,6 +750,18 @@ static llvm::cl::opt<bool, true> disableConstantPropOpt("disable-constant-prop",
llvm::cl::location(disableConstantProp), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirCommonOptions));

static llvm::cl::list<std::string, std::vector<std::string>>
discardableAttrsOpt("discardable-attrs",
llvm::cl::desc(
"Specify attribute names to mark as discardable.\n"
"Discardable attributes are prefixed with '_.' and can be "
"safely removed without affecting semantics.\n"
"Multiple attribute names can be specified.\n"
"Default: onnx_node_name"),
llvm::cl::location(discardableAttrs),
llvm::cl::list_init<std::string>({"onnx_node_name"}),
llvm::cl::cat(OnnxMlirOptions));

static llvm::cl::opt<uint64_t, true> compilation_num_threads("j",
llvm::cl::desc("Use <int> threads for compilation. The default value is "
"0, which spawns threads for all available CPUs.\n"),
Expand Down
3 changes: 2 additions & 1 deletion src/Compiler/CompilerOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ extern bool enableTiming; // onnx-mlir only
extern bool enableBoundCheck; // onnx-mlir only
extern bool debugTestCompilerOpt; // onnx-mlir only
extern bool useLinalgPath; // onnx-mlir only
extern std::string linalgOps; // common for both onnx-mlir and onnx-mlir-opt
extern std::vector<std::string> discardableAttrs; // onnx-mlir only
extern std::string linalgOps; // common for both

extern bool split_input_file; // onnx-mlir-opt only
extern bool verify_diagnostics; // onnx-mlir-opt only
Expand Down
9 changes: 9 additions & 0 deletions src/Compiler/CompilerPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,15 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
addLinalgToAffinePasses(pm);
}

// Convert attributes to discardable form if specified
if (!discardableAttrs.empty()) {
onnx_mlir::ConvertAttrToDiscardableOptions options;
llvm::SmallVector<std::string, 4> attrNames(
discardableAttrs.begin(), discardableAttrs.end());
options.attrNames = std::move(attrNames);
pm.addPass(onnx_mlir::createConvertAttrToDiscardable(options));
}

if (enableCSE)
// Eliminate common sub-expressions before lowering to Krnl.
// TODO: enable this by default when we make sure it works flawlessly.
Expand Down
4 changes: 4 additions & 0 deletions src/Pass/Passes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ std::unique_ptr<mlir::Pass> createConvertONNXToLinalg(
#define GEN_PASS_DECL_BUFFEROMPLOOPHOISTINGPASS
#include "src/Transform/Passes.h.inc"

/// Pass for converting attributes to discardable form.
#define GEN_PASS_DECL_CONVERTATTRTODISCARDABLE
#include "src/Transform/Passes.h.inc"

// The function registerTransformsPasses() is generated from Passes.td and used
// to register th pass for onnx-mlir-opt. Different Passes.td will generate the
// same name function. They have to be put into different name space to be
Expand Down
13 changes: 12 additions & 1 deletion src/Transform/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ add_onnx_mlir_library(OMInstrument

add_onnx_mlir_library(BufferOMPLoopHoist
BufferOMPLoopHoist.cpp

DEPENDS
OMTransformsPassIncGen

Expand All @@ -66,3 +66,14 @@ add_onnx_mlir_library(BufferOMPLoopHoist
MLIRMemRefDialect
OMOptionUtils
)

add_onnx_mlir_library(OMConvertAttrToDiscardable
ConvertAttrToDiscardable.cpp

DEPENDS
OMTransformsPassIncGen

LINK_LIBS PUBLIC
MLIRPass
OMSupport
)
69 changes: 69 additions & 0 deletions src/Transform/ConvertAttrToDiscardable.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===-- ConvertAttrToDiscardable.cpp - Convert Attributes to Discardable -===//
//
// Copyright 2019-2024 The IBM Research Authors.
//
// =============================================================================
//
// This file implements a pass that converts attributes to discardable form
// by prefixing them with '_.', based on the provided attribute names.
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/raw_ostream.h"

#include "src/Pass/Passes.hpp"

using namespace mlir;

namespace onnx_mlir {

#define GEN_PASS_DEF_CONVERTATTRTODISCARDABLE
#include "src/Transform/Passes.h.inc"

/*!
* This pass converts attributes to discardable form
*/

class ConvertAttrToDiscardable
: public impl::ConvertAttrToDiscardableBase<ConvertAttrToDiscardable> {
using Base::Base; // Inherit generated constructors (so options get wired)

public:
void runOnOperation() override {
Operation *rootOp = getOperation();

// Early return if no attribute names provided
if (attrNames.empty()) {
return;
}

// Walk through all operations in the module/function
rootOp->walk([&](Operation *op) {
// Process each attribute name from the option
for (const std::string &attrName : attrNames) {
// Check if the operation has this attribute
if (op->hasAttr(attrName)) {
// Get the existing attribute
Attribute attr = op->getAttr(attrName);

// Remove the original non-discardable version
op->removeAttr(attrName);

// Since the CSE does not ignore the discardable attribute,
// the attribute is simply removed for current implementation.
// Set as a discardable attribute
// op->setDiscardableAttr(attrName, attr);
}
}
});
}
};

} // namespace onnx_mlir
12 changes: 12 additions & 0 deletions src/Transform/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,17 @@ def BufferOMPLoopHoistingPass : Pass<"buffer-omploop-hoisting", "func::FuncOp">
}];
}

def ConvertAttrToDiscardable : Pass<"convert-attr-to-discardable"> {
let summary = "Convert attributes to discardable form";
let description = [{
This pass converts specified attributes to discardable form by
prefixing them with '_.', based on the provided attribute names.
}];

let options = [
ListOption<"attrNames", "attr-names", "std::string",
"List of attribute names to mark as discardable">
];
}

#endif
34 changes: 34 additions & 0 deletions test/mlir/onnx/onnx_set_discardable_attr.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// RUN: onnx-mlir-opt --convert-attr-to-discardable=attr-names=onnx_node_name --cse %s -split-input-file | FileCheck %s

// COM: Test that CSE can eliminate duplicate operations after converting to discardable attributes
// COM: Two Sqrt operations with same input but different onnx_node_name should be merged by CSE

func.func @test_cse_with_discardable_attr(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
%0 = "onnx.Sqrt"(%arg0) {onnx_node_name = "Sqrt_1"} : (tensor<10x10xf32>) -> tensor<10x10xf32>
%1 = "onnx.Sqrt"(%arg0) {onnx_node_name = "Sqrt_2"} : (tensor<10x10xf32>) -> tensor<10x10xf32>
%2 = "onnx.Add"(%0, %1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
onnx.Return %2 : tensor<10x10xf32>

// CHECK-LABEL: func @test_cse_with_discardable_attr
// CHECK: [[VAR_0_:%.+]] = "onnx.Sqrt"(%arg0)
// CHECK-NOT: "onnx.Sqrt"
// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[VAR_0_]], [[VAR_0_]])
}

// -----

// COM: Test that CSE does NOT eliminate operations with non-discardable attributes
// COM: Two Sqrt operations with different unknown_attr should NOT be merged by CSE
// COM: because unknown_attr is not in the attr-names list and remains non-discardable

func.func @test_no_cse_with_non_discardable_attr(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
%0 = "onnx.Sqrt"(%arg0) {unknown_attr = "Sqrt_1"} : (tensor<10x10xf32>) -> tensor<10x10xf32>
%1 = "onnx.Sqrt"(%arg0) {unknown_attr = "Sqrt_2"} : (tensor<10x10xf32>) -> tensor<10x10xf32>
%2 = "onnx.Add"(%0, %1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
onnx.Return %2 : tensor<10x10xf32>

// CHECK-LABEL: func @test_no_cse_with_non_discardable_attr
// CHECK: [[VAR_0_:%.+]] = "onnx.Sqrt"(%arg0) {unknown_attr = "Sqrt_1"}
// CHECK: [[VAR_1_:%.+]] = "onnx.Sqrt"(%arg0) {unknown_attr = "Sqrt_2"}
// CHECK: [[VAR_2_:%.+]] = "onnx.Add"([[VAR_0_]], [[VAR_1_]])
}
Loading