diff --git a/tmva/sofie/inc/TMVA/ROperator_BasicUnary.hxx b/tmva/sofie/inc/TMVA/ROperator_BasicUnary.hxx index 15fb4a01ebe15..d6af4bae69c20 100644 --- a/tmva/sofie/inc/TMVA/ROperator_BasicUnary.hxx +++ b/tmva/sofie/inc/TMVA/ROperator_BasicUnary.hxx @@ -9,7 +9,17 @@ namespace TMVA { namespace Experimental { namespace SOFIE { -enum class EBasicUnaryOperator { kReciprocal, kSqrt , kNeg, kExp, kLog, kSin, kCos, kAbs }; +enum class EBasicUnaryOperator { + kReciprocal, + kSqrt, + kNeg, + kExp, + kLog, + kSin, + kCos, + kAbs, + kRound +}; template struct UnaryOpTraits { @@ -63,6 +73,15 @@ struct UnaryOpTraits { static std::string Op(const std::string &X) { return "std::abs(" + X + ")"; } }; +template +struct UnaryOpTraits { + static std::string Name() { return "Round"; } + static std::string Op(const std::string &X) + { + return "(std::fabs(" + X + "- std::trunc(" + X + ")) == 0.5) ? std::trunc(" + X + ") : std::round(" + X + ");"; + } +}; + template class ROperator_BasicUnary final : public ROperator { private: @@ -109,7 +128,8 @@ public: } std::vector GetStdLibs() override { - if (Op == EBasicUnaryOperator::kSqrt || Op == EBasicUnaryOperator::kExp || Op == EBasicUnaryOperator::kLog) { + if (Op == EBasicUnaryOperator::kSqrt || Op == EBasicUnaryOperator::kExp || Op == EBasicUnaryOperator::kLog || + Op == EBasicUnaryOperator::kRound) { return { std::string("cmath") }; } else { return {}; diff --git a/tmva/sofie/test/TestCustomModelsFromONNX.cxx b/tmva/sofie/test/TestCustomModelsFromONNX.cxx index 2caab4fb2cc4b..3af48308b6698 100644 --- a/tmva/sofie/test/TestCustomModelsFromONNX.cxx +++ b/tmva/sofie/test/TestCustomModelsFromONNX.cxx @@ -27,6 +27,9 @@ #include "Cast_FromONNX.hxx" #include "input_models/references/Cast.ref.hxx" +#include "CastLike_FromONNX.hxx" +#include "input_models/references/CastLike.ref.hxx" + #include "ReduceMean_FromONNX.hxx" #include "input_models/references/ReduceMean.ref.hxx" @@ -271,6 +274,9 @@ #include "Log_FromONNX.hxx" #include "input_models/references/Log.ref.hxx" +#include "Round_FromONNX.hxx" +#include "input_models/references/Round.ref.hxx" + #include "Elu_FromONNX.hxx" #include "input_models/references/Elu.ref.hxx" @@ -323,6 +329,9 @@ #include "ScatterElements_FromONNX.hxx" +#include "Not_FromONNX.hxx" +#include "input_models/references/Not.ref.hxx" + #include "gtest/gtest.h" constexpr float DEFAULT_TOLERANCE = 1e-3f; @@ -518,6 +527,27 @@ TEST(ONNX, Neg) } } + TEST(ONNX, Not) + { + constexpr float TOLERANCE = DEFAULT_TOLERANCE; + + // Preparing the standard input + std::vector input({-0.7077, 1.0645, -0.8607, 0.2085, 4.5335, -3.4592}); + + TMVA_SOFIE_Not::Session s("Not_FromONNX.dat"); + std::vector output = s.infer(input.data()); + + // Checking output size + EXPECT_EQ(output.size(), sizeof(Not_ExpectedOutput::outputs) / sizeof(float)); + + float *correct = Not_ExpectedOutput::outputs; + + // Checking every output value, one by one + for (size_t i = 0; i < output.size(); ++i) { + EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE); + } + } + TEST(ONNX, Elu) { constexpr float TOLERANCE = DEFAULT_TOLERANCE; @@ -690,6 +720,28 @@ TEST(ONNX, Cast) } } +TEST(ONNX, CastLike) +{ + constexpr float TOLERANCE = DEFAULT_TOLERANCE; + + // Preparing the standard input + std::vector input_1({3, 23, -8, 1}); + + TMVA_SOFIE_CastLike::Session s("CastLike_FromONNX.dat"); + + auto output = s.infer(input_1.data()); + + // Checking output size + EXPECT_EQ(output.size(), sizeof(CastLike_ExpectedOutput::outputs) / sizeof(float)); + + float *correct = CastLike_ExpectedOutput::outputs; + + // Checking every output value, one by one + for (size_t i = 0; i < output.size(); ++i) { + EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE); + } +} + TEST(ONNX, Linear64) { constexpr float TOLERANCE = DEFAULT_TOLERANCE; @@ -808,6 +860,28 @@ TEST(ONNX, Log) } } +TEST(ONNX, Round) +{ + constexpr float TOLERANCE = DEFAULT_TOLERANCE; + + // Preparing the random input + std::vector input({1.3, -4.5, 7.9, -2.6}); + + TMVA_SOFIE_Round::Session s("Round_FromONNX.dat"); + + std::vector output = s.infer(input.data()); + + // Checking output size + EXPECT_EQ(output.size(), sizeof(Round_ExpectedOutput::outputs) / sizeof(float)); + + float *correct = Round_ExpectedOutput::outputs; + + // Checking every output value, one by one + for (size_t i = 0; i < output.size(); ++i) { + EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE); + } +} + TEST(ONNX, LinearWithLeakyRelu) { constexpr float TOLERANCE = 1; diff --git a/tmva/sofie/test/input_models/CastLike.onnx b/tmva/sofie/test/input_models/CastLike.onnx new file mode 100644 index 0000000000000..6d2c254c91e1f --- /dev/null +++ b/tmva/sofie/test/input_models/CastLike.onnx @@ -0,0 +1,15 @@ + + onnx-example:� +Dinput2"Constant*0 +value*$"�̼@33 A��L�ff>�B const_tensor� +" +input1 +input2output"CastLike CastLikeGraphZ +input1 + + +b +output + + +B \ No newline at end of file diff --git a/tmva/sofie/test/input_models/Not.onnx b/tmva/sofie/test/input_models/Not.onnx new file mode 100644 index 0000000000000..54d7c080ea438 --- /dev/null +++ b/tmva/sofie/test/input_models/Not.onnx @@ -0,0 +1,12 @@ + + onnx-example:M + +input1output"NotNotGraphZ +input1 + + +b +output + + +B \ No newline at end of file diff --git a/tmva/sofie/test/input_models/Round.onnx b/tmva/sofie/test/input_models/Round.onnx new file mode 100644 index 0000000000000..32a3ddaf6a641 --- /dev/null +++ b/tmva/sofie/test/input_models/Round.onnx @@ -0,0 +1,13 @@ + + onnx-example:O + +inputoutput"Round +RoundGraphZ +input + + +b +output + + +B \ No newline at end of file diff --git a/tmva/sofie/test/input_models/references/CastLike.ref.hxx b/tmva/sofie/test/input_models/references/CastLike.ref.hxx new file mode 100644 index 0000000000000..fc7e9b76626a9 --- /dev/null +++ b/tmva/sofie/test/input_models/references/CastLike.ref.hxx @@ -0,0 +1,3 @@ +namespace CastLike_ExpectedOutput { +float outputs[] = {3.0, 23.0, -8.0, 1.0}; +} // namespace CastLike_ExpectedOutput diff --git a/tmva/sofie/test/input_models/references/Not.ref.hxx b/tmva/sofie/test/input_models/references/Not.ref.hxx new file mode 100644 index 0000000000000..e420e25c57109 --- /dev/null +++ b/tmva/sofie/test/input_models/references/Not.ref.hxx @@ -0,0 +1,3 @@ +namespace Not_ExpectedOutput { +float outputs[] = {0.7077, -1.0645, 0.8607, -0.2085, -4.5335, 3.4592}; +} // namespace Not_ExpectedOutput diff --git a/tmva/sofie/test/input_models/references/Round.ref.hxx b/tmva/sofie/test/input_models/references/Round.ref.hxx new file mode 100644 index 0000000000000..6548a29c93190 --- /dev/null +++ b/tmva/sofie/test/input_models/references/Round.ref.hxx @@ -0,0 +1,3 @@ +namespace Round_ExpectedOutput { +float outputs[] = {1.0, -4.0, 8.0, -3.0}; +} // namespace Round_ExpectedOutput diff --git a/tmva/sofie_parsers/src/ParseBasicUnary.cxx b/tmva/sofie_parsers/src/ParseBasicUnary.cxx index 1292161c585c2..322edaac8e979 100644 --- a/tmva/sofie_parsers/src/ParseBasicUnary.cxx +++ b/tmva/sofie_parsers/src/ParseBasicUnary.cxx @@ -80,6 +80,11 @@ ParserFuncSignature ParseAbs = [](RModelParser_ONNX &parser, const onnx::NodePro return ParseBasicUnary(parser, nodeproto); }; +// Parse Round +ParserFuncSignature ParseRound = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { + return ParseBasicUnary(parser, nodeproto); +}; + } // namespace SOFIE } // namespace Experimental } // namespace TMVA diff --git a/tmva/sofie_parsers/src/ParseCast.cxx b/tmva/sofie_parsers/src/ParseCast.cxx index 0e07fd7d164fc..2a487f703f48e 100644 --- a/tmva/sofie_parsers/src/ParseCast.cxx +++ b/tmva/sofie_parsers/src/ParseCast.cxx @@ -33,6 +33,31 @@ ParserFuncSignature ParseCast = [](RModelParser_ONNX &parser, const onnx::NodePr return op; }; +ParserFuncSignature ParseCastLike = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { + auto input_name = nodeproto.input(0); + auto target_type_tensor_name = nodeproto.input(1); + if (!parser.IsRegisteredTensorType(input_name)) { + throw std::runtime_error("TMVA::SOFIE ONNX Parser Cast op has input tensor" + input_name + + " but its type is not yet registered"); + } + if (!parser.IsRegisteredTensorType(target_type_tensor_name)) { + throw std::runtime_error("TMVA::SOFIE ONNX Parser Cast op has target type tensor" + target_type_tensor_name + + " but its type is not yet registered"); + } + + std::unique_ptr op; + std::string target_type = ConvertTypeToString(parser.GetTensorType(target_type_tensor_name)); + std::string output_name = nodeproto.output(0); + op.reset(new ROperator_Cast(target_type, input_name, output_name)); + + if (!parser.IsRegisteredTensorType(output_name)) { + ETensorType output_type = ConvertStringToType(target_type); + parser.RegisterTensorType(output_name, output_type); + } + + return op; +}; + } // namespace SOFIE } // namespace Experimental } // namespace TMVA diff --git a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx index 7b4ade2b6bc09..bc65e8258df3c 100644 --- a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx +++ b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx @@ -25,6 +25,8 @@ extern ParserFuncSignature ParseLog; extern ParserFuncSignature ParseSin; extern ParserFuncSignature ParseCos; extern ParserFuncSignature ParseAbs; +extern ParserFuncSignature ParseRound; + // Binary operators extern ParserFuncSignature ParseAdd; extern ParserFuncSignature ParseSub; @@ -69,6 +71,7 @@ extern ParserFuncSignature ParseIdentity; extern ParserFuncSignature ParseSoftmax; extern ParserFuncSignature ParseConcat; extern ParserFuncSignature ParseCast; +extern ParserFuncSignature ParseCastLike; extern ParserFuncSignature ParseExpand; extern ParserFuncSignature ParseShape; extern ParserFuncSignature ParseMatMul; @@ -158,11 +161,14 @@ RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_un RegisterOperator("Sqrt", ParseSqrt); RegisterOperator("Reciprocal", ParseReciprocal); RegisterOperator("Neg", ParseNeg); + RegisterOperator("Not", ParseNeg); RegisterOperator("Exp", ParseExp); RegisterOperator("Log", ParseLog); RegisterOperator("Sin", ParseSin); RegisterOperator("Cos", ParseCos); RegisterOperator("Abs", ParseAbs); + RegisterOperator("Round", ParseRound); + // Binary operators RegisterOperator("Add", ParseAdd); RegisterOperator("Sub", ParseSub); @@ -190,6 +196,7 @@ RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_un RegisterOperator("Constant", ParseConstant); RegisterOperator("ConstantOfShape", ParseConstant); RegisterOperator("Cast", ParseCast); + RegisterOperator("CastLike", ParseCastLike); RegisterOperator("Concat", ParseConcat); RegisterOperator("Conv", ParseConv); RegisterOperator("ConvTranspose", ParseConvTranspose);