
Commit 0fd683c

[tmva][sofie] fix round operation and add tests

1 parent b8a1bc0

10 files changed: +143 -8 lines

Diff for: tmva/sofie/inc/TMVA/ROperator_BasicUnary.hxx

+17-3
@@ -9,7 +9,17 @@ namespace TMVA {
 namespace Experimental {
 namespace SOFIE {

-enum class EBasicUnaryOperator { kReciprocal, kSqrt , kNeg, kExp, kLog, kSin, kCos, kAbs };
+enum class EBasicUnaryOperator {
+   kReciprocal,
+   kSqrt,
+   kNeg,
+   kExp,
+   kLog,
+   kSin,
+   kCos,
+   kAbs,
+   kRound
+};

 template <typename T, EBasicUnaryOperator Op>
 struct UnaryOpTraits {

@@ -66,7 +76,10 @@ struct UnaryOpTraits<T, EBasicUnaryOperator::kAbs> {
 template <typename T>
 struct UnaryOpTraits<T, EBasicUnaryOperator::kRound> {
    static std::string Name() { return "Round"; }
-   static std::string Op(const std::string &X) { return "std::round(" + X + ")"; }
+   static std::string Op(const std::string &X)
+   {
+      return "(std::fabs(" + X + "- std::trunc(" + X + ")) == 0.5) ? std::trunc(" + X + ") : std::round(" + X + ");";
+   }
 };

 template <typename T, EBasicUnaryOperator Op>

@@ -115,7 +128,8 @@ public:
    }

    std::vector<std::string> GetStdLibs() override {
-      if (Op == EBasicUnaryOperator::kSqrt || Op == EBasicUnaryOperator::kExp || Op == EBasicUnaryOperator::kLog) {
+      if (Op == EBasicUnaryOperator::kSqrt || Op == EBasicUnaryOperator::kExp || Op == EBasicUnaryOperator::kLog ||
+          Op == EBasicUnaryOperator::kRound) {
          return { std::string("cmath") };
       } else {
          return {};
Diff for: tmva/sofie/test/TestCustomModelsFromONNX.cxx

+74
@@ -27,6 +27,9 @@
 #include "Cast_FromONNX.hxx"
 #include "input_models/references/Cast.ref.hxx"

+#include "CastLike_FromONNX.hxx"
+#include "input_models/references/CastLike.ref.hxx"
+
 #include "ReduceMean_FromONNX.hxx"
 #include "input_models/references/ReduceMean.ref.hxx"

@@ -271,6 +274,9 @@
 #include "Log_FromONNX.hxx"
 #include "input_models/references/Log.ref.hxx"

+#include "Round_FromONNX.hxx"
+#include "input_models/references/Round.ref.hxx"
+
 #include "Elu_FromONNX.hxx"
 #include "input_models/references/Elu.ref.hxx"

@@ -323,6 +329,9 @@

 #include "ScatterElements_FromONNX.hxx"

+#include "Not_FromONNX.hxx"
+#include "input_models/references/Not.ref.hxx"
+
 #include "gtest/gtest.h"

 constexpr float DEFAULT_TOLERANCE = 1e-3f;

@@ -518,6 +527,27 @@ TEST(ONNX, Neg)
    }
 }

+TEST(ONNX, Not)
+{
+   constexpr float TOLERANCE = DEFAULT_TOLERANCE;
+
+   // Preparing the standard input
+   std::vector<float> input({-0.7077, 1.0645, -0.8607, 0.2085, 4.5335, -3.4592});
+
+   TMVA_SOFIE_Not::Session s("Not_FromONNX.dat");
+   std::vector<float> output = s.infer(input.data());
+
+   // Checking output size
+   EXPECT_EQ(output.size(), sizeof(Not_ExpectedOutput::outputs) / sizeof(float));
+
+   float *correct = Not_ExpectedOutput::outputs;
+
+   // Checking every output value, one by one
+   for (size_t i = 0; i < output.size(); ++i) {
+      EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE);
+   }
+}
+
 TEST(ONNX, Elu)
 {
    constexpr float TOLERANCE = DEFAULT_TOLERANCE;

@@ -690,6 +720,28 @@ TEST(ONNX, Cast)
    }
 }

+TEST(ONNX, CastLike)
+{
+   constexpr float TOLERANCE = DEFAULT_TOLERANCE;
+
+   // Preparing the standard input
+   std::vector<int64_t> input_1({3, 23, -8, 1});
+
+   TMVA_SOFIE_CastLike::Session s("CastLike_FromONNX.dat");
+
+   auto output = s.infer(input_1.data());
+
+   // Checking output size
+   EXPECT_EQ(output.size(), sizeof(CastLike_ExpectedOutput::outputs) / sizeof(float));
+
+   float *correct = CastLike_ExpectedOutput::outputs;
+
+   // Checking every output value, one by one
+   for (size_t i = 0; i < output.size(); ++i) {
+      EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE);
+   }
+}
+
 TEST(ONNX, Linear64)
 {
    constexpr float TOLERANCE = DEFAULT_TOLERANCE;

@@ -808,6 +860,28 @@ TEST(ONNX, Log)
    }
 }

+TEST(ONNX, Round)
+{
+   constexpr float TOLERANCE = DEFAULT_TOLERANCE;
+
+   // Preparing the random input
+   std::vector<float> input({1.3, -4.5, 7.9, -2.6});
+
+   TMVA_SOFIE_Round::Session s("Round_FromONNX.dat");
+
+   std::vector<float> output = s.infer(input.data());
+
+   // Checking output size
+   EXPECT_EQ(output.size(), sizeof(Round_ExpectedOutput::outputs) / sizeof(float));
+
+   float *correct = Round_ExpectedOutput::outputs;
+
+   // Checking every output value, one by one
+   for (size_t i = 0; i < output.size(); ++i) {
+      EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE);
+   }
+}
+
 TEST(ONNX, LinearWithLeakyRelu)
 {
    constexpr float TOLERANCE = 1;
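
For context, the Round_FromONNX.hxx, Not_FromONNX.hxx and CastLike_FromONNX.hxx headers included above are emitted by the test build from the corresponding .onnx files. A minimal sketch of that step, using SOFIE's ONNX parser as its tutorial API is usually shown (the exact calls and file names are an assumption and may differ between ROOT versions):

// Sketch only: generate a SOFIE inference header from an ONNX file.
// Assumes the RModelParser_ONNX / RModel API; verify against your ROOT version.
#include "TMVA/RModelParser_ONNX.hxx"

int main()
{
   using namespace TMVA::Experimental::SOFIE;
   RModelParser_ONNX parser;
   RModel model = parser.Parse("Round.onnx");   // build the model graph
   model.Generate();                            // emit the inference code
   model.OutputGenerated("Round_FromONNX.hxx"); // write header (plus the .dat weight file)
   return 0;
}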

Diff for: tmva/sofie/test/input_models/CastLike.onnx

+15
(binary ONNX protobuf, new file: a Constant node providing a constant tensor as input2 and a CastLike node mapping input1, input2 -> output; raw bytes not reproducible here)

Diff for: tmva/sofie/test/input_models/Not.onnx

+12
(binary ONNX protobuf, new file: a single Not node mapping input1 -> output; raw bytes not reproducible here)

Diff for: tmva/sofie/test/input_models/Round.onnx

+13
(binary ONNX protobuf, new file: a single Round node mapping input -> output; raw bytes not reproducible here)
Diff for: tmva/sofie/test/input_models/references/CastLike.ref.hxx

+3

@@ -0,0 +1,3 @@
+namespace CastLike_ExpectedOutput {
+float outputs[] = {3.0, 23.0, -8.0, 1.0};
+} // namespace CastLike_ExpectedOutput

Diff for: tmva/sofie/test/input_models/references/Not.ref.hxx

+3
@@ -0,0 +1,3 @@
+namespace Not_ExpectedOutput {
+float outputs[] = {0.7077, -1.0645, 0.8607, -0.2085, -4.5335, 3.4592};
+} // namespace Not_ExpectedOutput
Diff for: tmva/sofie/test/input_models/references/Round.ref.hxx

+3

@@ -0,0 +1,3 @@
+namespace Round_ExpectedOutput {
+float outputs[] = {1, -4, 8, -3};
+} // namespace Round_ExpectedOutput

Diff for: tmva/sofie_parsers/src/ParseBasicUnary.cxx

+1-1
@@ -81,7 +81,7 @@ ParserFuncSignature ParseAbs = [](RModelParser_ONNX &parser, const onnx::NodePro
 };

 // Parse Round
-ParserFuncSignature ParseAbs = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
+ParserFuncSignature ParseRound = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
    return ParseBasicUnary<EBasicUnaryOperator::kRound>(parser, nodeproto);
 };

Diff for: tmva/sofie_parsers/src/ParseCast.cxx

+2-4
@@ -46,20 +46,18 @@ ParserFuncSignature ParseCastLike = [](RModelParser_ONNX &parser, const onnx::No
    }

    std::unique_ptr<ROperator> op;
-   std::string target_type = parser.GetTensorType(target_type_tensor_name);
+   std::string target_type = ConvertTypeToString(parser.GetTensorType(target_type_tensor_name));
    std::string output_name = nodeproto.output(0);
    op.reset(new ROperator_Cast(target_type, input_name, output_name));

    if (!parser.IsRegisteredTensorType(output_name)) {
-      ETensorType output_type = ConvertStringToType(attr_type);
+      ETensorType output_type = ConvertStringToType(target_type);
       parser.RegisterTensorType(output_name, output_type);
    }

    return op;
 };

-
-
 } // namespace SOFIE
 } // namespace Experimental
 } // namespace TMVA
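
The ParseCastLike change boils down to a type round-trip: the target type is read as an ETensorType from the tensor the CastLike node copies its type from, converted to a string for ROperator_Cast, and converted back to an enum when the output tensor type is registered, instead of passing the enum where a string was expected and reusing a stale attr_type string. A simplified, self-contained sketch; ETensorType, ConvertTypeToString and ConvertStringToType below are hypothetical stand-ins for the SOFIE originals:

// Simplified sketch; the enum and the two conversion helpers are minimal
// stand-ins for SOFIE's, only to illustrate the enum -> string -> enum flow.
#include <cassert>
#include <string>

enum class ETensorType { FLOAT, INT64, UNDEFINED };

std::string ConvertTypeToString(ETensorType t)
{
   switch (t) {
   case ETensorType::FLOAT: return "float";
   case ETensorType::INT64: return "int64_t";
   default: return "undefined";
   }
}

ETensorType ConvertStringToType(const std::string &s)
{
   if (s == "float")   return ETensorType::FLOAT;
   if (s == "int64_t") return ETensorType::INT64;
   return ETensorType::UNDEFINED;
}

int main()
{
   // Type of the tensor CastLike copies from (the constant in CastLike.onnx).
   ETensorType target = ETensorType::FLOAT;

   // Fixed flow: enum -> string for ROperator_Cast, then string -> enum
   // when registering the output tensor type.
   std::string target_type = ConvertTypeToString(target);
   ETensorType output_type = ConvertStringToType(target_type);
   assert(output_type == target);
   (void)output_type;
   return 0;
}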

0 commit comments
