Skip to content

[tmva][sofie] Add new operators with feedback from 1st NGT Hackathon #18348

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions tmva/sofie/inc/TMVA/ROperator_BasicUnary.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,17 @@ namespace TMVA {
namespace Experimental {
namespace SOFIE {

enum class EBasicUnaryOperator { kReciprocal, kSqrt , kNeg, kExp, kLog, kSin, kCos, kAbs };
enum class EBasicUnaryOperator {
kReciprocal,
kSqrt,
kNeg,
kExp,
kLog,
kSin,
kCos,
kAbs,
kRound
};

template <typename T, EBasicUnaryOperator Op>
struct UnaryOpTraits {
Expand Down Expand Up @@ -63,6 +73,15 @@ struct UnaryOpTraits<T, EBasicUnaryOperator::kAbs> {
static std::string Op(const std::string &X) { return "std::abs(" + X + ")"; }
};

template <typename T>
struct UnaryOpTraits<T, EBasicUnaryOperator::kRound> {
static std::string Name() { return "Round"; }
static std::string Op(const std::string &X)
{
return "(std::fabs(" + X + "- std::trunc(" + X + ")) == 0.5) ? std::trunc(" + X + ") : std::round(" + X + ");";
}
};

template <typename T, EBasicUnaryOperator Op>
class ROperator_BasicUnary final : public ROperator {
private:
Expand Down Expand Up @@ -109,7 +128,8 @@ public:
}

std::vector<std::string> GetStdLibs() override {
if (Op == EBasicUnaryOperator::kSqrt || Op == EBasicUnaryOperator::kExp || Op == EBasicUnaryOperator::kLog) {
if (Op == EBasicUnaryOperator::kSqrt || Op == EBasicUnaryOperator::kExp || Op == EBasicUnaryOperator::kLog ||
Op == EBasicUnaryOperator::kRound) {
return { std::string("cmath") };
} else {
return {};
Expand Down
74 changes: 74 additions & 0 deletions tmva/sofie/test/TestCustomModelsFromONNX.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
#include "Cast_FromONNX.hxx"
#include "input_models/references/Cast.ref.hxx"

#include "CastLike_FromONNX.hxx"
#include "input_models/references/CastLike.ref.hxx"

#include "ReduceMean_FromONNX.hxx"
#include "input_models/references/ReduceMean.ref.hxx"

Expand Down Expand Up @@ -271,6 +274,9 @@
#include "Log_FromONNX.hxx"
#include "input_models/references/Log.ref.hxx"

#include "Round_FromONNX.hxx"
#include "input_models/references/Round.ref.hxx"

#include "Elu_FromONNX.hxx"
#include "input_models/references/Elu.ref.hxx"

Expand Down Expand Up @@ -323,6 +329,9 @@

#include "ScatterElements_FromONNX.hxx"

#include "Not_FromONNX.hxx"
#include "input_models/references/Not.ref.hxx"

#include "gtest/gtest.h"

constexpr float DEFAULT_TOLERANCE = 1e-3f;
Expand Down Expand Up @@ -518,6 +527,27 @@ TEST(ONNX, Neg)
}
}

TEST(ONNX, Not)
{
constexpr float TOLERANCE = DEFAULT_TOLERANCE;

// Preparing the standard input
std::vector<float> input({-0.7077, 1.0645, -0.8607, 0.2085, 4.5335, -3.4592});

TMVA_SOFIE_Not::Session s("Not_FromONNX.dat");
std::vector<float> output = s.infer(input.data());

// Checking output size
EXPECT_EQ(output.size(), sizeof(Not_ExpectedOutput::outputs) / sizeof(float));

float *correct = Not_ExpectedOutput::outputs;

// Checking every output value, one by one
for (size_t i = 0; i < output.size(); ++i) {
EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE);
}
}

TEST(ONNX, Elu)
{
constexpr float TOLERANCE = DEFAULT_TOLERANCE;
Expand Down Expand Up @@ -690,6 +720,28 @@ TEST(ONNX, Cast)
}
}

TEST(ONNX, CastLike)
{
constexpr float TOLERANCE = DEFAULT_TOLERANCE;

// Preparing the standard input
std::vector<int64_t> input_1({3, 23, -8, 1});

TMVA_SOFIE_CastLike::Session s("CastLike_FromONNX.dat");

auto output = s.infer(input_1.data());

// Checking output size
EXPECT_EQ(output.size(), sizeof(CastLike_ExpectedOutput::outputs) / sizeof(float));

float *correct = CastLike_ExpectedOutput::outputs;

// Checking every output value, one by one
for (size_t i = 0; i < output.size(); ++i) {
EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE);
}
}

TEST(ONNX, Linear64)
{
constexpr float TOLERANCE = DEFAULT_TOLERANCE;
Expand Down Expand Up @@ -808,6 +860,28 @@ TEST(ONNX, Log)
}
}

TEST(ONNX, Round)
{
constexpr float TOLERANCE = DEFAULT_TOLERANCE;

// Preparing the random input
std::vector<float> input({1.3, -4.5, 7.9, -2.6});

TMVA_SOFIE_Round::Session s("Round_FromONNX.dat");

std::vector<float> output = s.infer(input.data());

// Checking output size
EXPECT_EQ(output.size(), sizeof(Round_ExpectedOutput::outputs) / sizeof(float));

float *correct = Round_ExpectedOutput::outputs;

// Checking every output value, one by one
for (size_t i = 0; i < output.size(); ++i) {
EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE);
}
}

TEST(ONNX, LinearWithLeakyRelu)
{
constexpr float TOLERANCE = 1;
Expand Down
15 changes: 15 additions & 0 deletions tmva/sofie/test/input_models/CastLike.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

 onnx-example:¥
Dinput2"Constant*0
value*$"Í̼@33 AÍÌLÀff>ÁB const_tensor 
"
input1
input2output"CastLikeCastLikeGraphZ
input1


b
output


B
Expand Down
12 changes: 12 additions & 0 deletions tmva/sofie/test/input_models/Not.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@

 onnx-example:M

input1output"NotNotGraphZ
input1


b
output


B
13 changes: 13 additions & 0 deletions tmva/sofie/test/input_models/Round.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@

 onnx-example:O

inputoutput"Round
RoundGraphZ
input


b
output


B
3 changes: 3 additions & 0 deletions tmva/sofie/test/input_models/references/CastLike.ref.hxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
namespace CastLike_ExpectedOutput {
float outputs[] = {3.0, 23.0, -8.0, 1.0};
} // namespace CastLike_ExpectedOutput
3 changes: 3 additions & 0 deletions tmva/sofie/test/input_models/references/Not.ref.hxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
namespace Not_ExpectedOutput {
float outputs[] = {0.7077, -1.0645, 0.8607, -0.2085, -4.5335, 3.4592};
} // namespace Not_ExpectedOutput
3 changes: 3 additions & 0 deletions tmva/sofie/test/input_models/references/Round.ref.hxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
namespace Round_ExpectedOutput {
float outputs[] = {1.0, -4.0, 8.0, -3.0};
} // namespace Round_ExpectedOutput
5 changes: 5 additions & 0 deletions tmva/sofie_parsers/src/ParseBasicUnary.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ ParserFuncSignature ParseAbs = [](RModelParser_ONNX &parser, const onnx::NodePro
return ParseBasicUnary<EBasicUnaryOperator::kAbs>(parser, nodeproto);
};

// Parse Round
ParserFuncSignature ParseRound = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
return ParseBasicUnary<EBasicUnaryOperator::kRound>(parser, nodeproto);
};

} // namespace SOFIE
} // namespace Experimental
} // namespace TMVA
25 changes: 25 additions & 0 deletions tmva/sofie_parsers/src/ParseCast.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,31 @@ ParserFuncSignature ParseCast = [](RModelParser_ONNX &parser, const onnx::NodePr
return op;
};

ParserFuncSignature ParseCastLike = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
auto input_name = nodeproto.input(0);
auto target_type_tensor_name = nodeproto.input(1);
if (!parser.IsRegisteredTensorType(input_name)) {
throw std::runtime_error("TMVA::SOFIE ONNX Parser Cast op has input tensor" + input_name +
" but its type is not yet registered");
}
if (!parser.IsRegisteredTensorType(target_type_tensor_name)) {
throw std::runtime_error("TMVA::SOFIE ONNX Parser Cast op has target type tensor" + target_type_tensor_name +
" but its type is not yet registered");
}

std::unique_ptr<ROperator> op;
std::string target_type = ConvertTypeToString(parser.GetTensorType(target_type_tensor_name));
std::string output_name = nodeproto.output(0);
op.reset(new ROperator_Cast(target_type, input_name, output_name));

if (!parser.IsRegisteredTensorType(output_name)) {
ETensorType output_type = ConvertStringToType(target_type);
parser.RegisterTensorType(output_name, output_type);
}

return op;
};

} // namespace SOFIE
} // namespace Experimental
} // namespace TMVA
7 changes: 7 additions & 0 deletions tmva/sofie_parsers/src/RModelParser_ONNX.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ extern ParserFuncSignature ParseLog;
extern ParserFuncSignature ParseSin;
extern ParserFuncSignature ParseCos;
extern ParserFuncSignature ParseAbs;
extern ParserFuncSignature ParseRound;

// Binary operators
extern ParserFuncSignature ParseAdd;
extern ParserFuncSignature ParseSub;
Expand Down Expand Up @@ -69,6 +71,7 @@ extern ParserFuncSignature ParseIdentity;
extern ParserFuncSignature ParseSoftmax;
extern ParserFuncSignature ParseConcat;
extern ParserFuncSignature ParseCast;
extern ParserFuncSignature ParseCastLike;
extern ParserFuncSignature ParseExpand;
extern ParserFuncSignature ParseShape;
extern ParserFuncSignature ParseMatMul;
Expand Down Expand Up @@ -158,11 +161,14 @@ RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_un
RegisterOperator("Sqrt", ParseSqrt);
RegisterOperator("Reciprocal", ParseReciprocal);
RegisterOperator("Neg", ParseNeg);
RegisterOperator("Not", ParseNeg);
RegisterOperator("Exp", ParseExp);
RegisterOperator("Log", ParseLog);
RegisterOperator("Sin", ParseSin);
RegisterOperator("Cos", ParseCos);
RegisterOperator("Abs", ParseAbs);
RegisterOperator("Round", ParseRound);

// Binary operators
RegisterOperator("Add", ParseAdd);
RegisterOperator("Sub", ParseSub);
Expand Down Expand Up @@ -190,6 +196,7 @@ RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_un
RegisterOperator("Constant", ParseConstant);
RegisterOperator("ConstantOfShape", ParseConstant);
RegisterOperator("Cast", ParseCast);
RegisterOperator("CastLike", ParseCastLike);
RegisterOperator("Concat", ParseConcat);
RegisterOperator("Conv", ParseConv);
RegisterOperator("ConvTranspose", ParseConvTranspose);
Expand Down
Loading