-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
Copy pathParseCast.cxx
63 lines (51 loc) · 2.33 KB
/
ParseCast.cxx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include "TMVA/RModelParser_ONNX.hxx"
#include "TMVA/ROperator_Cast.hxx"
#include "onnx_proto3.pb.h"
namespace TMVA {
namespace Experimental {
namespace SOFIE {
ParserFuncSignature ParseCast = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
auto input_name = nodeproto.input(0);
if (!parser.IsRegisteredTensorType(input_name)) {
throw std::runtime_error("TMVA::SOFIE ONNX Parser Cast op has input tensor" + input_name +
" but its type is not yet registered");
}
std::unique_ptr<ROperator> op;
std::string attr_type;
for (int_t i = 0; i < nodeproto.attribute_size(); i++) {
std::string attribute_name = nodeproto.attribute(i).name();
if (attribute_name == "to")
attr_type = ConvertTypeToString(static_cast<ETensorType>(nodeproto.attribute(i).i()));
}
std::string output_name = nodeproto.output(0);
op.reset(new ROperator_Cast(attr_type, nodeproto.input(0), output_name));
if (!parser.IsRegisteredTensorType(output_name)) {
ETensorType output_type = ConvertStringToType(attr_type);
parser.RegisterTensorType(output_name, output_type);
}
return op;
};
ParserFuncSignature ParseCastLike = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
auto input_name = nodeproto.input(0);
auto target_type_tensor_name = nodeproto.input(1);
if (!parser.IsRegisteredTensorType(input_name)) {
throw std::runtime_error("TMVA::SOFIE ONNX Parser Cast op has input tensor" + input_name +
" but its type is not yet registered");
}
if (!parser.IsRegisteredTensorType(target_type_tensor_name)) {
throw std::runtime_error("TMVA::SOFIE ONNX Parser Cast op has target type tensor" + target_type_tensor_name +
" but its type is not yet registered");
}
std::unique_ptr<ROperator> op;
std::string target_type = ConvertTypeToString(parser.GetTensorType(target_type_tensor_name));
std::string output_name = nodeproto.output(0);
op.reset(new ROperator_Cast(target_type, input_name, output_name));
if (!parser.IsRegisteredTensorType(output_name)) {
ETensorType output_type = ConvertStringToType(target_type);
parser.RegisterTensorType(output_name, output_type);
}
return op;
};
} // namespace SOFIE
} // namespace Experimental
} // namespace TMVA