|
| 1 | +#include <common/data_types.h> |
1 | 2 | #include <common/Shaper.h> |
2 | 3 | #include <common/StrKeyMap.h> |
3 | 4 | #include <common/helper.h> |
@@ -271,8 +272,9 @@ void OnnxConverter::HandleInitializer() { |
271 | 272 | ONNX_NAMESPACE::TensorProto_DataType_INT64) { |
272 | 273 | // TODO: shape of reshape layer |
273 | 274 | } else { |
274 | | - PNT(tensor.name(), tensor.data_type()); |
275 | | - DNN_ASSERT(false, ""); |
| 275 | + DNN_ASSERT(false, "The data type \"" + std::to_string(tensor.data_type()) + |
| 276 | + "\" of tensor \"" + |
| 277 | + tensor.name() + "\" is not supported"); |
276 | 278 | } |
277 | 279 | operands_.push_back(name); |
278 | 280 | } |
@@ -630,34 +632,38 @@ bool IsValidSupportedNodesVec(const std::vector<int> &supported_node_vec, |
630 | 632 | return false; |
631 | 633 | } |
632 | 634 |
|
633 | | -std::vector<std::vector<int>> OnnxConverter::GetSupportedNodes( |
| 635 | +expected<std::vector<std::vector<int>>, std::string> OnnxConverter::GetSupportedNodes( |
634 | 636 | ONNX_NAMESPACE::ModelProto model_proto) { |
635 | 637 | GOOGLE_PROTOBUF_VERIFY_VERSION; |
636 | 638 | ONNX_NAMESPACE::shape_inference::InferShapes(model_proto); |
637 | 639 | model_proto_ = model_proto; |
638 | | - HandleInitializer(); |
| 640 | + try { |
| 641 | + HandleInitializer(); |
639 | 642 |
|
640 | | - std::vector<std::vector<int>> supported_node_vecs; |
641 | | - std::vector<int> supported_node_vec; |
642 | | - for (int i = 0; i < model_proto.graph().node_size(); i++) { |
643 | | - bool supported; |
644 | | - std::string error_msg; |
645 | | - std::tie(supported, error_msg) = |
646 | | - IsNodeSupported(model_proto, model_proto.graph().node(i)); |
647 | | - if (supported) { |
648 | | - supported_node_vec.push_back(i); |
649 | | - } else { |
650 | | - if (IsValidSupportedNodesVec(supported_node_vec, model_proto)) { |
651 | | - supported_node_vecs.push_back(supported_node_vec); |
652 | | - supported_node_vec.clear(); |
| 643 | + std::vector<std::vector<int>> supported_node_vecs; |
| 644 | + std::vector<int> supported_node_vec; |
| 645 | + for (int i = 0; i < model_proto.graph().node_size(); i++) { |
| 646 | + bool supported; |
| 647 | + std::string error_msg; |
| 648 | + std::tie(supported, error_msg) = |
| 649 | + IsNodeSupported(model_proto, model_proto.graph().node(i)); |
| 650 | + if (supported) { |
| 651 | + supported_node_vec.push_back(i); |
| 652 | + } else { |
| 653 | + if (IsValidSupportedNodesVec(supported_node_vec, model_proto)) { |
| 654 | + supported_node_vecs.push_back(supported_node_vec); |
| 655 | + supported_node_vec.clear(); |
| 656 | + } |
653 | 657 | } |
654 | 658 | } |
| 659 | + if (IsValidSupportedNodesVec(supported_node_vec, model_proto)) { |
| 660 | + supported_node_vecs.push_back(supported_node_vec); |
| 661 | + } |
| 662 | + Clear(); |
| 663 | + return supported_node_vecs; |
| 664 | + } catch (std::exception &e) { |
| 665 | + return make_unexpected(e.what()); |
655 | 666 | } |
656 | | - if (IsValidSupportedNodesVec(supported_node_vec, model_proto)) { |
657 | | - supported_node_vecs.push_back(supported_node_vec); |
658 | | - } |
659 | | - Clear(); |
660 | | - return supported_node_vecs; |
661 | 667 | } |
662 | 668 |
|
663 | 669 | void OnnxConverter::Convert(const ONNX_NAMESPACE::ModelProto &model_proto, |
|
0 commit comments