-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
Copy pathParseBasicUnary.cxx
90 lines (73 loc) · 3.03 KB
/
ParseBasicUnary.cxx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include "TMVA/RModelParser_ONNX.hxx"
#include "TMVA/ROperator_BasicUnary.hxx"
#include "onnx_proto3.pb.h"
namespace TMVA {
namespace Experimental {
namespace SOFIE {
/// Generic ONNX parser for the element-wise unary operators in this file
/// (Sqrt, Reciprocal, Neg, Exp, Log, Sin, Cos, Abs, Round).
///
/// Takes the node's first input and first output tensor names, looks up the
/// registered type of the input tensor and builds a ROperator_BasicUnary<T, Op>
/// for it. If the output tensor has no registered type yet, it is registered
/// with the input tensor's type.
///
/// @tparam Op        the unary operation to instantiate
/// @param parser     ONNX parser holding the tensor-type registry (may be updated)
/// @param nodeproto  ONNX node to convert
/// @return the newly created operator
/// @throws std::runtime_error if the node lacks an input or output, if the input
///         tensor type is not registered, or if the input type is not FLOAT
template <EBasicUnaryOperator Op>
std::unique_ptr<ROperator> ParseBasicUnary(RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto)
{
   // Guard before indexing: protobuf repeated-field access with an out-of-range
   // index is not bounds-checked in release builds.
   if (nodeproto.input_size() < 1 || nodeproto.output_size() < 1) {
      throw std::runtime_error("TMVA::SOFIE ONNX Parser Unary op " + nodeproto.op_type() +
                               " requires at least one input and one output tensor");
   }

   const std::string &input_name = nodeproto.input(0);
   if (!parser.IsRegisteredTensorType(input_name)) {
      throw std::runtime_error("TMVA::SOFIE ONNX Parser Unary op has input tensor " + input_name +
                               " but its type is not yet registered");
   }
   const ETensorType input_type = parser.GetTensorType(input_name);
   const std::string &output_name = nodeproto.output(0);

   std::unique_ptr<ROperator> op;
   switch (input_type) {
   case ETensorType::FLOAT:
      op = std::make_unique<ROperator_BasicUnary<float, Op>>(input_name, output_name);
      break;
   default:
      // Only float kernels are instantiated here; report the unsupported enum value.
      throw std::runtime_error("TMVA::SOFIE - Unsupported - Unary Operator does not yet support input type " +
                               std::to_string(static_cast<int>(input_type)));
   }

   // Infer the output type: register it with the input type unless already known.
   if (!parser.IsRegisteredTensorType(output_name)) {
      parser.RegisterTensorType(output_name, input_type);
   }
   return op;
}
/// Parser for the Sqrt operator: forwards the node to ParseBasicUnary
/// instantiated with EBasicUnaryOperator::kSqrt.
ParserFuncSignature ParseSqrt = [](RModelParser_ONNX &onnxParser, const onnx::NodeProto &node) {
   return ParseBasicUnary<EBasicUnaryOperator::kSqrt>(onnxParser, node);
};
/// Parser for the Reciprocal operator: forwards the node to ParseBasicUnary
/// instantiated with EBasicUnaryOperator::kReciprocal.
ParserFuncSignature ParseReciprocal = [](RModelParser_ONNX &onnxParser, const onnx::NodeProto &node) {
   return ParseBasicUnary<EBasicUnaryOperator::kReciprocal>(onnxParser, node);
};
/// Parser for the Neg operator: forwards the node to ParseBasicUnary
/// instantiated with EBasicUnaryOperator::kNeg.
ParserFuncSignature ParseNeg = [](RModelParser_ONNX &onnxParser, const onnx::NodeProto &node) {
   return ParseBasicUnary<EBasicUnaryOperator::kNeg>(onnxParser, node);
};
/// Parser for the Exp operator: forwards the node to ParseBasicUnary
/// instantiated with EBasicUnaryOperator::kExp.
ParserFuncSignature ParseExp = [](RModelParser_ONNX &onnxParser, const onnx::NodeProto &node) {
   return ParseBasicUnary<EBasicUnaryOperator::kExp>(onnxParser, node);
};
/// Parser for the Log operator: forwards the node to ParseBasicUnary
/// instantiated with EBasicUnaryOperator::kLog.
ParserFuncSignature ParseLog = [](RModelParser_ONNX &onnxParser, const onnx::NodeProto &node) {
   return ParseBasicUnary<EBasicUnaryOperator::kLog>(onnxParser, node);
};
/// Parser for the Sin operator: forwards the node to ParseBasicUnary
/// instantiated with EBasicUnaryOperator::kSin.
ParserFuncSignature ParseSin = [](RModelParser_ONNX &onnxParser, const onnx::NodeProto &node) {
   return ParseBasicUnary<EBasicUnaryOperator::kSin>(onnxParser, node);
};
/// Parser for the Cos operator: forwards the node to ParseBasicUnary
/// instantiated with EBasicUnaryOperator::kCos.
ParserFuncSignature ParseCos = [](RModelParser_ONNX &onnxParser, const onnx::NodeProto &node) {
   return ParseBasicUnary<EBasicUnaryOperator::kCos>(onnxParser, node);
};
/// Parser for the Abs operator: forwards the node to ParseBasicUnary
/// instantiated with EBasicUnaryOperator::kAbs.
ParserFuncSignature ParseAbs = [](RModelParser_ONNX &onnxParser, const onnx::NodeProto &node) {
   return ParseBasicUnary<EBasicUnaryOperator::kAbs>(onnxParser, node);
};
/// Parser for the Round operator: forwards the node to ParseBasicUnary
/// instantiated with EBasicUnaryOperator::kRound.
ParserFuncSignature ParseRound = [](RModelParser_ONNX &onnxParser, const onnx::NodeProto &node) {
   return ParseBasicUnary<EBasicUnaryOperator::kRound>(onnxParser, node);
};
} // namespace SOFIE
} // namespace Experimental
} // namespace TMVA