Skip to content

Commit e17f11e

Browse files
authored
Merge pull request #71 from JDAI-CV/fix_exception_in_ort
catch exceptions in GetSupportedNodes
2 parents bac3cce + 4efa112 commit e17f11e

File tree

4 files changed

+40
-26
lines changed

4 files changed

+40
-26
lines changed

ci/onnxruntime_test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ pool:
2727
steps:
2828
- checkout: self
2929
submodules: true
30-
- script: git clone --recursive --branch fix_android_build https://github.com/daquexian/onnxruntime $(Agent.HomeDirectory)/onnxruntime
30+
- script: git clone --recursive --branch android https://github.com/daquexian/onnxruntime $(Agent.HomeDirectory)/onnxruntime
3131
displayName: Clone ONNX Runtime
3232
- script: rm -rf $(Agent.HomeDirectory)/onnxruntime/cmake/external/DNNLibrary && cp -r $(Build.SourcesDirectory) $(Agent.HomeDirectory)/onnxruntime/cmake/external/DNNLibrary
3333
displayName: Copy latest DNNLibrary

include/tools/onnx2daq/OnnxConverter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ class OnnxConverter {
215215
void Clear();
216216

217217
public:
218-
std::vector<std::vector<int>> GetSupportedNodes(
218+
expected<std::vector<std::vector<int>>, std::string> GetSupportedNodes(
219219
ONNX_NAMESPACE::ModelProto model_proto);
220220
void Convert(const std::string &model_str, const std::string &filepath,
221221
const std::string &table_file = "");

tools/getsupportednodes/getsupportednodes.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,14 @@ int main(int argc, char *argv[])
1212
// FIXME: Handle the return value
1313
model_proto.ParseFromString(ss.str());
1414
dnn::OnnxConverter converter;
15-
PNT(converter.GetSupportedNodes(model_proto));
16-
return 0;
15+
const auto nodes = converter.GetSupportedNodes(model_proto);
16+
if (nodes) {
17+
const auto &supported_ops = nodes.value();
18+
PNT(supported_ops);
19+
return 0;
20+
} else {
21+
const auto &error = nodes.error();
22+
PNT(error);
23+
return 1;
24+
}
1725
}

tools/onnx2daq/OnnxConverter.cpp

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <common/data_types.h>
12
#include <common/Shaper.h>
23
#include <common/StrKeyMap.h>
34
#include <common/helper.h>
@@ -271,8 +272,9 @@ void OnnxConverter::HandleInitializer() {
271272
ONNX_NAMESPACE::TensorProto_DataType_INT64) {
272273
// TODO: shape of reshape layer
273274
} else {
274-
PNT(tensor.name(), tensor.data_type());
275-
DNN_ASSERT(false, "");
275+
DNN_ASSERT(false, "The data type \"" + std::to_string(tensor.data_type()) +
276+
"\" of tensor \"" +
277+
tensor.name() + "\" is not supported");
276278
}
277279
operands_.push_back(name);
278280
}
@@ -630,34 +632,38 @@ bool IsValidSupportedNodesVec(const std::vector<int> &supported_node_vec,
630632
return false;
631633
}
632634

633-
std::vector<std::vector<int>> OnnxConverter::GetSupportedNodes(
635+
expected<std::vector<std::vector<int>>, std::string> OnnxConverter::GetSupportedNodes(
634636
ONNX_NAMESPACE::ModelProto model_proto) {
635637
GOOGLE_PROTOBUF_VERIFY_VERSION;
636638
ONNX_NAMESPACE::shape_inference::InferShapes(model_proto);
637639
model_proto_ = model_proto;
638-
HandleInitializer();
640+
try {
641+
HandleInitializer();
639642

640-
std::vector<std::vector<int>> supported_node_vecs;
641-
std::vector<int> supported_node_vec;
642-
for (int i = 0; i < model_proto.graph().node_size(); i++) {
643-
bool supported;
644-
std::string error_msg;
645-
std::tie(supported, error_msg) =
646-
IsNodeSupported(model_proto, model_proto.graph().node(i));
647-
if (supported) {
648-
supported_node_vec.push_back(i);
649-
} else {
650-
if (IsValidSupportedNodesVec(supported_node_vec, model_proto)) {
651-
supported_node_vecs.push_back(supported_node_vec);
652-
supported_node_vec.clear();
643+
std::vector<std::vector<int>> supported_node_vecs;
644+
std::vector<int> supported_node_vec;
645+
for (int i = 0; i < model_proto.graph().node_size(); i++) {
646+
bool supported;
647+
std::string error_msg;
648+
std::tie(supported, error_msg) =
649+
IsNodeSupported(model_proto, model_proto.graph().node(i));
650+
if (supported) {
651+
supported_node_vec.push_back(i);
652+
} else {
653+
if (IsValidSupportedNodesVec(supported_node_vec, model_proto)) {
654+
supported_node_vecs.push_back(supported_node_vec);
655+
supported_node_vec.clear();
656+
}
653657
}
654658
}
659+
if (IsValidSupportedNodesVec(supported_node_vec, model_proto)) {
660+
supported_node_vecs.push_back(supported_node_vec);
661+
}
662+
Clear();
663+
return supported_node_vecs;
664+
} catch (std::exception &e) {
665+
return make_unexpected(e.what());
655666
}
656-
if (IsValidSupportedNodesVec(supported_node_vec, model_proto)) {
657-
supported_node_vecs.push_back(supported_node_vec);
658-
}
659-
Clear();
660-
return supported_node_vecs;
661667
}
662668

663669
void OnnxConverter::Convert(const ONNX_NAMESPACE::ModelProto &model_proto,

0 commit comments

Comments
 (0)