Skip to content

Commit 77c4e43

Browse files
committed
[onnx2circle] Add option to list supported operators
This commit adds --list-operators option to the onnx2circle tool which lists all supported ONNX operators with their opset numbers. Such list should help in ONNX model operators conversion in case when converted model uses not supported operators versions. > $ onnx2circle --list-operators > AveragePool-1 > AveragePool-7 > AveragePool-10 > ... > MatMul-1 > MatMul-9 > ... ONE-DCO-1.0-Signed-off-by: Arkadiusz Bokowy <a.bokowy@samsung.com>
1 parent 30a67e1 commit 77c4e43

File tree

7 files changed

+87
-23
lines changed

7 files changed

+87
-23
lines changed

compiler/mir/include/mir_onnx_importer/ONNXImporterImpl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@
2121

2222
#include <memory>
2323
#include <string>
24+
#include <vector>
2425

2526
namespace mir_onnx
2627
{
2728

29+
std::vector<std::string> getSupportedOperators();
2830
std::unique_ptr<mir::Graph> importModelFromBinaryFile(const std::string &filename);
2931
std::unique_ptr<mir::Graph> importModelFromTextFile(const std::string &filename);
3032
// TODO Remove after changing all uses.

compiler/mir/src/mir_onnx_importer/ONNXImporterImpl.cpp

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "ONNXImporterImpl.h"
1818
#include "ONNXHelpers.h"
1919
#include "ONNXOpRegistration.h"
20+
#include "ONNXNodeConverterRegistry.h"
2021
#include "onnx/onnx.pb.h"
2122

2223
#include "mir/Shape.h"
@@ -32,7 +33,9 @@
3233
#include <functional>
3334
#include <iostream>
3435
#include <memory>
36+
#include <string>
3537
#include <utility>
38+
#include <vector>
3639

3740
namespace mir_onnx
3841
{
@@ -45,6 +48,7 @@ class ONNXImporterImpl final
4548
public:
4649
ONNXImporterImpl();
4750
~ONNXImporterImpl();
51+
std::vector<std::string> getSupportedOps();
4852
/// @brief Load the model and convert it into a MIR Graph.
4953
std::unique_ptr<mir::Graph> importModelFromBinaryFile(const std::string &filename);
5054
std::unique_ptr<mir::Graph> importModelFromTextFile(const std::string &filename);
@@ -104,6 +108,16 @@ void loadModelFromTextFile(const std::string &filename, onnx::ModelProto *model)
104108
throw std::runtime_error("Couldn't parse file \"" + filename + "\".");
105109
}
106110

111+
std::vector<std::string> ONNXImporterImpl::getSupportedOps()
112+
{
113+
std::vector<std::string> ops;
114+
for (const auto &op : NodeConverterRegistry::getInstance().getSupportedOperators())
115+
{
116+
ops.push_back(op.first + "-" + std::to_string(op.second));
117+
}
118+
return ops;
119+
}
120+
107121
std::unique_ptr<mir::Graph> ONNXImporterImpl::importModelFromBinaryFile(const std::string &filename)
108122
{
109123
_model = std::make_unique<onnx::ModelProto>();
@@ -124,7 +138,7 @@ std::unique_ptr<mir::Graph> ONNXImporterImpl::importModelFromTextFile(const std:
124138

125139
void ONNXImporterImpl::collectUnsupportedOps()
126140
{
127-
std::set<std::pair<std::string, int64_t>> problems_op_set;
141+
std::set<NodeConverterRegistry::Operator> problems_op_set;
128142

129143
for (int i = 0; i < _model->graph().node_size(); i++)
130144
{
@@ -134,7 +148,7 @@ void ONNXImporterImpl::collectUnsupportedOps()
134148
auto opset = _modelCtx->getDomainOpsetVersion(onnx_node.domain());
135149

136150
NodeConverterRegistry::ConverterFunc converter =
137-
NodeConverterRegistry::getInstance().lookup(op_type, opset);
151+
NodeConverterRegistry::getInstance().lookup({op_type, opset});
138152

139153
if (converter == nullptr)
140154
problems_op_set.emplace(op_type, opset);
@@ -143,7 +157,7 @@ void ONNXImporterImpl::collectUnsupportedOps()
143157
{
144158
std::cerr << "The following operators are not supported:\n";
145159
for (const auto &op : problems_op_set)
146-
std::cerr << op.first << " opset " << op.second << std::endl;
160+
std::cerr << op.first << "-" << op.second << std::endl;
147161
throw std::runtime_error("Unsupported operators found");
148162
}
149163
}
@@ -199,7 +213,7 @@ std::unique_ptr<mir::Graph> ONNXImporterImpl::createIR()
199213
auto opset = _modelCtx->getDomainOpsetVersion(onnx_node.domain());
200214
// Get converter
201215
NodeConverterRegistry::ConverterFunc converter =
202-
NodeConverterRegistry::getInstance().lookup(op_type, opset);
216+
NodeConverterRegistry::getInstance().lookup({op_type, opset});
203217
assert(converter != nullptr);
204218
converter(onnx_node, _converterCtx.get());
205219
}
@@ -220,6 +234,12 @@ std::unique_ptr<mir::Graph> ONNXImporterImpl::createIR()
220234

221235
} // namespace
222236

237+
std::vector<std::string> getSupportedOperators()
238+
{
239+
ONNXImporterImpl importer;
240+
return importer.getSupportedOps();
241+
}
242+
223243
std::unique_ptr<mir::Graph> importModelFromBinaryFile(const std::string &filename)
224244
{
225245
ONNXImporterImpl importer;

compiler/mir/src/mir_onnx_importer/ONNXNodeConverterRegistry.cpp

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,10 @@ void ConverterContext::setNodeOutputs(const onnx::NodeProto &onnx_node,
105105

106106
// NodeConverterRegistry
107107

108-
NodeConverterRegistry::ConverterFunc NodeConverterRegistry::lookup(const std::string &optype,
109-
int64_t opset) const
108+
NodeConverterRegistry::ConverterFunc
109+
NodeConverterRegistry::lookup(const NodeConverterRegistry::Operator &op) const
110110
{
111-
auto it = _converter_map.find(optype);
111+
auto it = _converter_map.find(op.first);
112112
if (it == _converter_map.end())
113113
{
114114
return nullptr;
@@ -117,8 +117,8 @@ NodeConverterRegistry::ConverterFunc NodeConverterRegistry::lookup(const std::st
117117
const VersionMap &conv_map = it->second;
118118

119119
auto res = std::lower_bound(
120-
conv_map.crbegin(), conv_map.crend(), opset,
121-
[](const VersionMap::value_type &pair, int64_t opset) { return pair.first > opset; });
120+
conv_map.crbegin(), conv_map.crend(), op.second,
121+
[](const VersionMap::value_type &pair, unsigned int opset) { return pair.first > opset; });
122122

123123
if (res == conv_map.crend())
124124
{
@@ -133,10 +133,34 @@ NodeConverterRegistry &NodeConverterRegistry::getInstance()
133133
return instance;
134134
}
135135

136-
void NodeConverterRegistry::registerConverter(const std::string &op_type, int64_t opset,
136+
void NodeConverterRegistry::registerConverter(const NodeConverterRegistry::Operator &op,
137137
NodeConverterRegistry::ConverterFunc conv)
138138
{
139-
_converter_map[op_type].emplace(opset, conv);
139+
_converter_map[op.first].emplace(op.second, conv);
140+
}
141+
142+
std::vector<NodeConverterRegistry::Operator> NodeConverterRegistry::getSupportedOperators() const
143+
{
144+
std::vector<Operator> ops;
145+
146+
for (const auto &op : _converter_map)
147+
{
148+
for (const auto &version : op.second)
149+
{
150+
// Get only supported operators
151+
if (version.second != nullptr)
152+
{
153+
ops.push_back({op.first, version.first});
154+
}
155+
}
156+
}
157+
158+
// Sort operators alphabetically for consistent output
159+
std::sort(ops.begin(), ops.end(), [](const auto &op1, const auto &op2) {
160+
return (op1.first == op2.first) ? (op1.second < op2.second) : (op1.first < op2.first);
161+
});
162+
163+
return ops;
140164
}
141165

142166
} // namespace mir_onnx

compiler/mir/src/mir_onnx_importer/ONNXNodeConverterRegistry.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,17 +60,20 @@ class ConverterContext
6060
class NodeConverterRegistry
6161
{
6262
public:
63+
// @brief ONNX operator defined by its name and opset number.
64+
using Operator = std::pair<std::string, unsigned int>;
6365
using ConverterFunc = void (*)(const onnx::NodeProto &onnx_node, ConverterContext *context);
6466

6567
NodeConverterRegistry() = default;
6668

67-
ConverterFunc lookup(const std::string &optype, int64_t opset) const;
68-
void registerConverter(const std::string &op_type, int64_t opset, ConverterFunc conv);
69+
ConverterFunc lookup(const Operator &op) const;
70+
void registerConverter(const Operator &op, ConverterFunc conv);
71+
std::vector<Operator> getSupportedOperators() const;
6972

7073
static NodeConverterRegistry &getInstance();
7174

7275
private:
73-
using VersionMap = std::map<int64_t, ConverterFunc>;
76+
using VersionMap = std::map<unsigned int, ConverterFunc>;
7477

7578
std::unordered_map<std::string, VersionMap> _converter_map;
7679
};

compiler/mir/src/mir_onnx_importer/ONNXNodeConverterRegistry.test.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,35 +30,42 @@ class NodeConverterRegsitryTest : public ::testing::Test
3030
protected:
3131
void SetUp() override
3232
{
33-
registry.registerConverter("dummy", 1, converterV1);
34-
registry.registerConverter("dummy", 3, converterV3);
35-
registry.registerConverter("dummy", 7, converterV7);
36-
registry.registerConverter("dummy", firstUnknownOpset, nullptr);
33+
registry.registerConverter({"dummy", 1}, converterV1);
34+
registry.registerConverter({"dummy", 3}, converterV3);
35+
registry.registerConverter({"dummy", 7}, converterV7);
36+
registry.registerConverter({"dummy", firstUnknownOpset}, nullptr);
3737
}
3838

3939
NodeConverterRegistry registry;
4040
};
4141

4242
TEST_F(NodeConverterRegsitryTest, existing_lookup_works)
4343
{
44-
auto res = registry.lookup("dummy", 1);
44+
auto res = registry.lookup({"dummy", 1});
4545
ASSERT_EQ(res, &converterV1);
4646
}
4747

4848
TEST_F(NodeConverterRegsitryTest, skipped_lookup_works)
4949
{
50-
auto res = registry.lookup("dummy", 2);
50+
auto res = registry.lookup({"dummy", 2});
5151
ASSERT_EQ(res, &converterV1);
5252
}
5353

5454
TEST_F(NodeConverterRegsitryTest, first_unknown_version_works)
5555
{
56-
auto res = registry.lookup("dummy", 14);
56+
auto res = registry.lookup({"dummy", 14});
5757
ASSERT_EQ(res, nullptr);
5858
}
5959

6060
TEST_F(NodeConverterRegsitryTest, lower_than_first_version)
6161
{
62-
auto res = registry.lookup("dummy", 0);
62+
auto res = registry.lookup({"dummy", 0});
6363
ASSERT_EQ(res, nullptr);
6464
}
65+
66+
TEST_F(NodeConverterRegsitryTest, registered_ops)
67+
{
68+
auto res = registry.getSupportedOperators();
69+
ASSERT_EQ(
70+
res, std::vector<NodeConverterRegistry::Operator>({{"dummy", 1}, {"dummy", 3}, {"dummy", 7}}));
71+
}

compiler/mir/src/mir_onnx_importer/ONNXOpRegistration.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ inline void registerSupportedOps()
6565
{
6666
auto &registry = NodeConverterRegistry::getInstance();
6767

68-
#define REG_CONVERTER(name, version, function) registry.registerConverter(name, version, function)
68+
#define REG_CONVERTER(name, version, function) registry.registerConverter({name, version}, function)
6969
#define REG(name, version) REG_CONVERTER(#name, version, convert##name##V##version)
7070
#define UNSUPPORTED(name, version) REG_CONVERTER(#name, version, nullptr)
7171

compiler/onnx2circle/src/onnx2circle.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ struct LoggingContext
6565
void print_help()
6666
{
6767
std::cerr << "Usage: onnx2circle <path/to/onnx> <path/to/circle/model> " << std::endl;
68+
std::cerr << " onnx2circle --list-operators" << std::endl;
6869
}
6970

7071
} // namespace
@@ -83,6 +84,13 @@ int main(int argc, char **argv)
8384

8485
LOGGER(l);
8586

87+
if (argc == 2 && std::string(argv[1]) == "--list-operators")
88+
{
89+
for (const auto &op : mir_onnx::getSupportedOperators())
90+
std::cout << op << std::endl;
91+
return 0;
92+
}
93+
8694
// TODO We need better args parsing in future
8795
if (!(argc == 3))
8896
{

0 commit comments

Comments
 (0)