From bb9cb6655a4bf5a0948b22454696742eabb1e192 Mon Sep 17 00:00:00 2001 From: Katie Graham Date: Wed, 6 Apr 2022 16:24:25 -0700 Subject: [PATCH 1/5] applied patch to expand operator layer macros --- .../lbann/operators/declare_stateless_op.hpp | 27 ++++ include/lbann/operators/math/binary.hpp | 57 ++++++--- .../operators/math/binary_with_constant.hpp | 115 +++++++++++++++++- include/lbann/operators/operator.hpp | 20 ++- 4 files changed, 199 insertions(+), 20 deletions(-) diff --git a/include/lbann/operators/declare_stateless_op.hpp b/include/lbann/operators/declare_stateless_op.hpp index 7ff0ebd531a..dc5d42eeb26 100644 --- a/include/lbann/operators/declare_stateless_op.hpp +++ b/include/lbann/operators/declare_stateless_op.hpp @@ -32,6 +32,16 @@ #include "lbann/proto/operators.pb.h" +#ifdef LBANN_HAS_ONNX +#define ADD_GET_ONNX_NODES_API() \ + std::vector get_onnx_nodes() const final \ + { \ + return get_onnx_nodes_impl(*this); \ + } +#else +#define ADD_GET_ONNX_NODES_API() +#endif // LBANN_HAS_ONNX + // These are all single-type operators. #define LBANN_DECLARE_STATELESS_OPERATOR(OP_NAME, OP_STRING) \ @@ -64,6 +74,7 @@ ar(::cereal::make_nvp("Operator", \ ::cereal::base_class(this))); \ } \ + ADD_GET_ONNX_NODES_API() \ void fp_compute(std::vector const& inputs, \ std::vector const& outputs) const final; \ void bp_compute( \ @@ -113,6 +124,7 @@ ar(::cereal::make_nvp("ElementwiseOperator", \ ::cereal::base_class(this))); \ } \ + ADD_GET_ONNX_NODES_API() \ \ private: \ void \ @@ -130,4 +142,19 @@ {} \ } +namespace lbann { + +#ifdef LBANN_HAS_ONNX +// Overloads of this function are used to implement the functions in +// the macro template above. +template +std::vector get_onnx_nodes_impl(OperatorT const& op) +{ + // The default assumption is that we don't know how to represent + // this operator in ONNX terms yet. + return {}; +} +#endif // LBANN_HAS_ONNX + +} // namespace lbann #endif // LBANN_INCLUDE_LBANN_OPERATORS_DECLARE_STATELESS_OP_HPP_INCLUDED diff --git a/include/lbann/operators/math/binary.hpp b/include/lbann/operators/math/binary.hpp index d9db70c21e9..9a639190736 100644 --- a/include/lbann/operators/math/binary.hpp +++ b/include/lbann/operators/math/binary.hpp @@ -29,34 +29,57 @@ #include "lbann/operators/declare_stateless_op.hpp" +#ifdef LBANN_HAS_ONNX +#define LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(OP_NAME, \ + OP_STRING, \ + OP_ONNX_NAME) \ + LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(OP_NAME, OP_STRING); \ + template \ + std::vector get_onnx_nodes_impl( \ + OP_NAME##Operator const& op) \ + { \ + std::vector nodes(1UL); \ + nodes.front().set_op_type(OP_ONNX_NAME); \ + return nodes; \ + } +#else +#define LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(OP_NAME, \ + OP_STRING, \ + OP_ONNX_NAME) \ + LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(OP_NAME, OP_STRING) +#endif // LBANN_HAS_ONNX + namespace lbann { // Arithmetic operations -LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Add, "add"); -LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Subtract, "subtract"); -LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Multiply, "multiply"); -LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Divide, "divide"); -LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Mod, "modulo"); -LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Pow, "power"); +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Add, "add", "Add") +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Subtract, "subtract", "Sub") +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Multiply, "multiply", "Mul") +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Divide, "divide", "Div") +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Mod, "modulo", "Mod") +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Pow, "power", "Pow") LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(SafeDivide, "safe divide"); LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(SquaredDifference, "squared difference"); // Comparison operations -LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Max, "maximum"); -LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Min, "minimum"); -LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Equal, "equal"); +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Max, "maximum", "Max") +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Min, "minimum", "Min") +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Equal, "equal", "Equal") LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(NotEqual, "not equal"); -LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Less, "less than"); -LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(LessEqual, "less than or equal"); -LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Greater, "greater than"); -LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(GreaterEqual, - "greater than or equal"); +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Less, "less than", "Less") +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(LessEqual, + "less than or equal", + "LessOrEqual") +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Greater, "greater than", "Greater") +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(GreaterEqual, + "greater than or equal", + "GreaterOrEqual") // Logical operations -LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(LogicalAnd, "logical and"); -LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(LogicalOr, "logical or"); -LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(LogicalXor, "logical xor"); +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(LogicalAnd, "logical and", "And") +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(LogicalOr, "logical or", "Or") +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(LogicalXor, "logical xor", "Xor") } // namespace lbann diff --git a/include/lbann/operators/math/binary_with_constant.hpp b/include/lbann/operators/math/binary_with_constant.hpp index c4d6dd7d6cc..024e4f7f8c2 100644 --- a/include/lbann/operators/math/binary_with_constant.hpp +++ b/include/lbann/operators/math/binary_with_constant.hpp @@ -32,7 +32,9 @@ #include "lbann/operators/elementwise_operator.hpp" #include "lbann/utils/cloneable.hpp" -#include "lbann/proto/operators.pb.h" +#include +#include + /** @file * @@ -50,6 +52,16 @@ #include "lbann/proto/operators.pb.h" +#ifdef LBANN_HAS_ONNX +#define ADD_GET_ONNX_NODES_API() \ + std::vector get_onnx_nodes() const final \ + { \ + return get_onnx_nodes_impl(*this); \ + } +#else +#define ADD_GET_ONNX_NODES_API() +#endif // LBANN_HAS_ONNX + // These are all single-type operators. #define LBANN_DECLARE_BINARY_WITH_CONSTANT_OPERATOR(OP_NAME, OP_STRING) \ @@ -88,6 +100,7 @@ ::cereal::base_class(this)), \ CEREAL_NVP(m_constant)); \ } \ + ADD_GET_ONNX_NODES_API() \ DataT get_constant() const noexcept \ { \ return m_constant; \ @@ -123,7 +136,7 @@ namespace lbann { // x + c -- treated as commutative. LBANN_DECLARE_BINARY_WITH_CONSTANT_OPERATOR(AddConstant, "add constant"); -// x + c -- treated as commutative. +// x * c -- treated as commutative. LBANN_DECLARE_BINARY_WITH_CONSTANT_OPERATOR(Scale, "scale"); // x - C -- yes, could be "plus -C", but so could 7-4 be 7+-4, but @@ -149,5 +162,103 @@ LBANN_DECLARE_BINARY_WITH_CONSTANT_OPERATOR(GreaterEqualConstant, LBANN_DECLARE_BINARY_WITH_CONSTANT_OPERATOR(GreaterConstant, "greater than constant"); +inline onnx::NodeProto get_constant_node(float val) +{ + onnx::NodeProto const_node; + auto* const_val = const_node.add_attribute(); + const_val->set_name("value_float"); + const_val->set_type(onnx::AttributeProto::FLOAT); + const_val->set_f(val); + return const_node; +} + +template +std::vector get_onnx_nodes_impl( + AddConstantOperator const op) +{ + std::vector nodes(2UL); + nodes.front() = get_constant_node(El::To(op.get_constant())); + nodes.front().set_op_type("PostConstant"); + nodes.back().set_op_type("Add"); + return nodes; +} + +template +std::vector get_onnx_nodes_impl( + ScaleOperator const) +{ + return {}; +} + +template +std::vector get_onnx_nodes_impl( + SubtractConstantOperator const) +{ + return {}; +} + +template +std::vector get_onnx_nodes_impl( + ConstantSubtractOperator const) +{ + return {}; +} + +template +std::vector get_onnx_nodes_impl( + MaxConstantOperator const) +{ + return {}; +} + +template +std::vector get_onnx_nodes_impl( + MinConstantOperator const) +{ + return {}; +} + +template +std::vector get_onnx_nodes_impl( + EqualConstantOperator const) +{ + return {}; +} + +template +std::vector get_onnx_nodes_impl( + NotEqualConstantOperator const) +{ + return {}; +} + +template +std::vector get_onnx_nodes_impl( + LessConstantOperator const) +{ + return {}; +} + +template +std::vector get_onnx_nodes_impl( + LessEqualConstantOperator const) +{ + return {}; +} + +template +std::vector get_onnx_nodes_impl( + GreaterConstantOperator const) +{ + return {}; +} + +template +std::vector get_onnx_nodes_impl( + GreaterEqualConstantOperator const) +{ + return {}; +} + } // namespace lbann #endif // LBANN_INCLUDE_LBANN_OPERATORS_BINARY_WITH_CONSTANT_HPP_INCLUDED diff --git a/include/lbann/operators/operator.hpp b/include/lbann/operators/operator.hpp index 0bbd838f0dd..e9a092a19ba 100644 --- a/include/lbann/operators/operator.hpp +++ b/include/lbann/operators/operator.hpp @@ -43,6 +43,10 @@ #include +#ifdef LBANN_HAS_ONNX +#include +#endif + #include #include @@ -130,6 +134,10 @@ class Operator : public AbstractCloneableBase>, template void serialize(ArchiveT& ar); +#ifdef LBANN_HAS_ONNX + virtual std::vector get_onnx_nodes() const; +#endif + ///@} /** @name Computational interface */ ///@{ @@ -164,7 +172,7 @@ class Operator : public AbstractCloneableBase>, virtual void set_proto_params(lbann_data::Operator&) const = 0; /** @brief Concrete operator description. */ virtual void do_fill_description(Description&) const = 0; -}; +}; // class Operator template void Operator::write_proto( @@ -208,5 +216,15 @@ template void Operator::serialize(ArchiveT& ar) {} +#ifdef LBANN_HAS_ONNX +template +std::vector Operator::get_onnx_nodes() const +{ + // The default assumption is that we don't know how to represent + // this operator in ONNX terms yet. + return {}; +} +#endif + } // namespace lbann #endif // LBANN_OPERATORS_OPERATOR_HPP_INCLUDED From a081d99c916d0dfc3a5d9c7c902a4a06638bf20c Mon Sep 17 00:00:00 2001 From: Katie Graham Date: Tue, 19 Apr 2022 09:47:41 -0700 Subject: [PATCH 2/5] added fill_onnx_node() to operator layer --- include/lbann/layers/operator_layer.hpp | 4 + .../operators/math/binary_with_constant.hpp | 95 ++++++++++++++----- src/layers/operator_layer.cpp | 59 ++++++++++++ 3 files changed, 135 insertions(+), 23 deletions(-) diff --git a/include/lbann/layers/operator_layer.hpp b/include/lbann/layers/operator_layer.hpp index 4546fa5fd8e..7daae25fe2e 100644 --- a/include/lbann/layers/operator_layer.hpp +++ b/include/lbann/layers/operator_layer.hpp @@ -81,6 +81,10 @@ class OperatorLayer final : public data_type_layer data_layout get_data_layout() const final; El::Device get_device_allocation() const final; +#ifdef LBANN_HAS_ONNX + void fill_onnx_node(onnx::GraphProto& graph) const override; +#endif //LBANN_HAS_ONNX + void fp_compute() final; void bp_compute() final; diff --git a/include/lbann/operators/math/binary_with_constant.hpp b/include/lbann/operators/math/binary_with_constant.hpp index 024e4f7f8c2..3ddb6821035 100644 --- a/include/lbann/operators/math/binary_with_constant.hpp +++ b/include/lbann/operators/math/binary_with_constant.hpp @@ -32,8 +32,9 @@ #include "lbann/operators/elementwise_operator.hpp" #include "lbann/utils/cloneable.hpp" +#ifdef LBANN_HAS_ONNX #include -#include +#endif // LBANN_HAS_ONNX /** @file @@ -165,6 +166,9 @@ LBANN_DECLARE_BINARY_WITH_CONSTANT_OPERATOR(GreaterConstant, inline onnx::NodeProto get_constant_node(float val) { onnx::NodeProto const_node; + const_node.add_output("const_val"); + const_node.set_domain(""); + const_node.set_doc_string("Const value for binary with constant operations"); auto* const_val = const_node.add_attribute(); const_val->set_name("value_float"); const_val->set_type(onnx::AttributeProto::FLOAT); @@ -185,79 +189,124 @@ std::vector get_onnx_nodes_impl( template std::vector get_onnx_nodes_impl( - ScaleOperator const) + ScaleOperator const op) { - return {}; + std::vector nodes(2UL); + nodes.front() = get_constant_node(El::To(op.get_constant())); + nodes.front().set_op_type("PostConstant"); + nodes.back().set_op_type("Mul"); + return nodes; } template std::vector get_onnx_nodes_impl( - SubtractConstantOperator const) + SubtractConstantOperator const op) { - return {}; + std::vector nodes(2UL); + nodes.front() = get_constant_node(El::To(op.get_constant())); + nodes.front().set_op_type("PostConstant"); + nodes.back().set_op_type("Sub"); + return nodes; } template std::vector get_onnx_nodes_impl( - ConstantSubtractOperator const) + ConstantSubtractOperator const op) { - return {}; + std::vector nodes(2UL); + nodes.front() = get_constant_node(El::To(op.get_constant())); + nodes.front().set_op_type("PreConstant"); + nodes.back().set_op_type("Sub"); + return nodes; } template std::vector get_onnx_nodes_impl( - MaxConstantOperator const) + MaxConstantOperator const op) { - return {}; + std::vector nodes(2UL); + nodes.front() = get_constant_node(El::To(op.get_constant())); + nodes.front().set_op_type("PreConstant"); + nodes.back().set_op_type("Max"); + return nodes; } template std::vector get_onnx_nodes_impl( - MinConstantOperator const) + MinConstantOperator const op) { - return {}; + std::vector nodes(2UL); + nodes.front() = get_constant_node(El::To(op.get_constant())); + nodes.front().set_op_type("PreConstant"); + nodes.back().set_op_type("Min"); + return nodes; } template std::vector get_onnx_nodes_impl( - EqualConstantOperator const) + EqualConstantOperator const op) { - return {}; + std::vector nodes(2UL); + nodes.front() = get_constant_node(El::To(op.get_constant())); + nodes.front().set_op_type("PreConstant"); + nodes.back().set_op_type("Equal"); + return nodes; } template std::vector get_onnx_nodes_impl( - NotEqualConstantOperator const) + NotEqualConstantOperator const op) { - return {}; + std::vector nodes(3UL); + nodes.front() = get_constant_node(El::To(op.get_constant())); + nodes.front().set_op_type("PreConstant"); + nodes.at(1).set_op_type("Not"); + nodes.back().set_op_type("Equal"); + return nodes; } template std::vector get_onnx_nodes_impl( - LessConstantOperator const) + LessConstantOperator const op) { - return {}; + std::vector nodes(2UL); + nodes.front() = get_constant_node(El::To(op.get_constant())); + nodes.front().set_op_type("PostConstant"); + nodes.back().set_op_type("Less"); + return nodes; } template std::vector get_onnx_nodes_impl( - LessEqualConstantOperator const) + LessEqualConstantOperator const op) { - return {}; + std::vector nodes(2UL); + nodes.front() = get_constant_node(El::To(op.get_constant())); + nodes.front().set_op_type("PostConstant"); + nodes.back().set_op_type("LessOrEqual"); + return nodes; } template std::vector get_onnx_nodes_impl( - GreaterConstantOperator const) + GreaterConstantOperator const op) { - return {}; + std::vector nodes(2UL); + nodes.front() = get_constant_node(El::To(op.get_constant())); + nodes.front().set_op_type("PreConstant"); + nodes.back().set_op_type("Greater"); + return nodes; } template std::vector get_onnx_nodes_impl( - GreaterEqualConstantOperator const) + GreaterEqualConstantOperator const op) { - return {}; + std::vector nodes(2UL); + nodes.front() = get_constant_node(El::To(op.get_constant())); + nodes.front().set_op_type("PreConstant"); + nodes.back().set_op_type("GreaterOrEqual"); + return nodes; } } // namespace lbann diff --git a/src/layers/operator_layer.cpp b/src/layers/operator_layer.cpp index b3e91a514e2..938549f707c 100644 --- a/src/layers/operator_layer.cpp +++ b/src/layers/operator_layer.cpp @@ -57,4 +57,63 @@ void OperatorLayer::write_specific_proto( op->set_device_allocation(proto::ProtoDevice); } +#ifdef LBANN_HAS_ONNX +template +void OperatorLayer::fill_onnx_node( + onnx::GraphProto& graph) const +{ + std::vector nodes(2UL); + nodes.front().add_attribute()->set_type(onnx::AttributeProto::FLOAT); + nodes.front().add_attribute()->set_f(El::To(5)); + nodes.front().set_op_type("PostConstant"); + nodes.back().set_op_type("Add"); + + //OperatorPtr op; + //auto nodes = op->get_onnx_nodes(); + const auto* parent = this->get_parent_layers()[0]; + + auto* const_node = graph.add_node(); + *const_node = nodes.front(); + + auto* node = graph.add_node(); + *node = nodes.back(); + node->set_name(this->get_name()); + node->set_domain(""); + node->set_doc_string(this->get_name()); + if(const_node->op_type() == "PostConstant") + { + node->add_input(parent->get_name() + "_0"); + node->add_input(const_node->output(0)); + const_node->set_op_type("Constant"); + } + else if(const_node->op_type() == "PreConstant") + { + node->add_input(const_node->output(0)); + node->add_input(parent->get_name() + "_0"); + const_node->set_op_type("Constant"); + } + else + LBANN_ERROR("Unknown onnx op type for constant."); + + // Not equal operator + if(nodes.size() == 3) + { + node->add_output("EqualOperator"); + auto* not_node = graph.add_node(); + not_node->add_input(node->output(0)); + not_node->add_output(this->get_child_layers()[0]->get_name() + "_0"); + not_node->set_name("Not operator"); + not_node->set_op_type("Not"); + not_node->set_domain(""); + not_node->set_doc_string("Not node for not equal operation."); + } + else if(nodes.size() == 2) + { + node->add_output(this->get_child_layers()[0]->get_name() + "_0"); + } + else + LBANN_ERROR("Expected two or three nodes for binary constant operation, received ", nodes.size()); +} +#endif // LBANN_HAS_ONNX + } // namespace lbann From 0ead78a5e80caf2589095991bc68af686767d4f6 Mon Sep 17 00:00:00 2001 From: Katie Graham Date: Thu, 21 Apr 2022 19:08:15 -0700 Subject: [PATCH 3/5] added binary operators to fill_onnx_node, refactored binary_with constant for consistency with location of operator in vector for regular binary operators --- .../operators/math/binary_with_constant.hpp | 72 +++++++++---------- src/layers/operator_layer.cpp | 70 ++++++++++-------- 2 files changed, 75 insertions(+), 67 deletions(-) diff --git a/include/lbann/operators/math/binary_with_constant.hpp b/include/lbann/operators/math/binary_with_constant.hpp index 3ddb6821035..c4312bf6575 100644 --- a/include/lbann/operators/math/binary_with_constant.hpp +++ b/include/lbann/operators/math/binary_with_constant.hpp @@ -181,9 +181,9 @@ std::vector get_onnx_nodes_impl( AddConstantOperator const op) { std::vector nodes(2UL); - nodes.front() = get_constant_node(El::To(op.get_constant())); - nodes.front().set_op_type("PostConstant"); - nodes.back().set_op_type("Add"); + nodes.front().set_op_type("Add"); + nodes.back() = get_constant_node(El::To(op.get_constant())); + nodes.back().set_op_type("PostConstant"); return nodes; } @@ -192,9 +192,9 @@ std::vector get_onnx_nodes_impl( ScaleOperator const op) { std::vector nodes(2UL); - nodes.front() = get_constant_node(El::To(op.get_constant())); - nodes.front().set_op_type("PostConstant"); - nodes.back().set_op_type("Mul"); + nodes.front().set_op_type("Mul"); + nodes.back() = get_constant_node(El::To(op.get_constant())); + nodes.back().set_op_type("PostConstant"); return nodes; } @@ -203,9 +203,9 @@ std::vector get_onnx_nodes_impl( SubtractConstantOperator const op) { std::vector nodes(2UL); - nodes.front() = get_constant_node(El::To(op.get_constant())); - nodes.front().set_op_type("PostConstant"); - nodes.back().set_op_type("Sub"); + nodes.front().set_op_type("Sub"); + nodes.back() = get_constant_node(El::To(op.get_constant())); + nodes.back().set_op_type("PostConstant"); return nodes; } @@ -214,9 +214,9 @@ std::vector get_onnx_nodes_impl( ConstantSubtractOperator const op) { std::vector nodes(2UL); - nodes.front() = get_constant_node(El::To(op.get_constant())); - nodes.front().set_op_type("PreConstant"); - nodes.back().set_op_type("Sub"); + nodes.front().set_op_type("Sub"); + nodes.back() = get_constant_node(El::To(op.get_constant())); + nodes.back().set_op_type("PreConstant"); return nodes; } @@ -225,9 +225,9 @@ std::vector get_onnx_nodes_impl( MaxConstantOperator const op) { std::vector nodes(2UL); - nodes.front() = get_constant_node(El::To(op.get_constant())); - nodes.front().set_op_type("PreConstant"); - nodes.back().set_op_type("Max"); + nodes.front().set_op_type("Max"); + nodes.back() = get_constant_node(El::To(op.get_constant())); + nodes.back().set_op_type("PreConstant"); return nodes; } @@ -236,9 +236,9 @@ std::vector get_onnx_nodes_impl( MinConstantOperator const op) { std::vector nodes(2UL); - nodes.front() = get_constant_node(El::To(op.get_constant())); - nodes.front().set_op_type("PreConstant"); - nodes.back().set_op_type("Min"); + nodes.front().set_op_type("Min"); + nodes.back() = get_constant_node(El::To(op.get_constant())); + nodes.back().set_op_type("PreConstant"); return nodes; } @@ -247,9 +247,9 @@ std::vector get_onnx_nodes_impl( EqualConstantOperator const op) { std::vector nodes(2UL); - nodes.front() = get_constant_node(El::To(op.get_constant())); - nodes.front().set_op_type("PreConstant"); - nodes.back().set_op_type("Equal"); + nodes.front().set_op_type("Equal"); + nodes.back() = get_constant_node(El::To(op.get_constant())); + nodes.back().set_op_type("PreConstant"); return nodes; } @@ -258,10 +258,10 @@ std::vector get_onnx_nodes_impl( NotEqualConstantOperator const op) { std::vector nodes(3UL); - nodes.front() = get_constant_node(El::To(op.get_constant())); - nodes.front().set_op_type("PreConstant"); + nodes.front().set_op_type("Equal"); + nodes.back() = get_constant_node(El::To(op.get_constant())); + nodes.back().set_op_type("PreConstant"); nodes.at(1).set_op_type("Not"); - nodes.back().set_op_type("Equal"); return nodes; } @@ -270,9 +270,9 @@ std::vector get_onnx_nodes_impl( LessConstantOperator const op) { std::vector nodes(2UL); - nodes.front() = get_constant_node(El::To(op.get_constant())); - nodes.front().set_op_type("PostConstant"); - nodes.back().set_op_type("Less"); + nodes.front().set_op_type("Less"); + nodes.back() = get_constant_node(El::To(op.get_constant())); + nodes.back().set_op_type("PostConstant"); return nodes; } @@ -281,9 +281,9 @@ std::vector get_onnx_nodes_impl( LessEqualConstantOperator const op) { std::vector nodes(2UL); - nodes.front() = get_constant_node(El::To(op.get_constant())); - nodes.front().set_op_type("PostConstant"); - nodes.back().set_op_type("LessOrEqual"); + nodes.front().set_op_type("LessOrEqual"); + nodes.back() = get_constant_node(El::To(op.get_constant())); + nodes.back().set_op_type("PostConstant"); return nodes; } @@ -292,9 +292,9 @@ std::vector get_onnx_nodes_impl( GreaterConstantOperator const op) { std::vector nodes(2UL); - nodes.front() = get_constant_node(El::To(op.get_constant())); - nodes.front().set_op_type("PreConstant"); - nodes.back().set_op_type("Greater"); + nodes.front().set_op_type("Greater"); + nodes.back() = get_constant_node(El::To(op.get_constant())); + nodes.back().set_op_type("PreConstant"); return nodes; } @@ -303,9 +303,9 @@ std::vector get_onnx_nodes_impl( GreaterEqualConstantOperator const op) { std::vector nodes(2UL); - nodes.front() = get_constant_node(El::To(op.get_constant())); - nodes.front().set_op_type("PreConstant"); - nodes.back().set_op_type("GreaterOrEqual"); + nodes.front().set_op_type("GreaterOrEqual"); + nodes.back() = get_constant_node(El::To(op.get_constant())); + nodes.back().set_op_type("PreConstant"); return nodes; } diff --git a/src/layers/operator_layer.cpp b/src/layers/operator_layer.cpp index 938549f707c..2d5b2355334 100644 --- a/src/layers/operator_layer.cpp +++ b/src/layers/operator_layer.cpp @@ -62,57 +62,65 @@ template void OperatorLayer::fill_onnx_node( onnx::GraphProto& graph) const { - std::vector nodes(2UL); - nodes.front().add_attribute()->set_type(onnx::AttributeProto::FLOAT); - nodes.front().add_attribute()->set_f(El::To(5)); - nodes.front().set_op_type("PostConstant"); - nodes.back().set_op_type("Add"); + const auto& parents = this->get_parent_layers(); + auto nodes = m_ops.front()->get_onnx_nodes(); - //OperatorPtr op; - //auto nodes = op->get_onnx_nodes(); - const auto* parent = this->get_parent_layers()[0]; + auto* op_node = graph.add_node(); + *op_node = nodes.front(); - auto* const_node = graph.add_node(); - *const_node = nodes.front(); + op_node->set_name(this->get_name()); + op_node->set_domain(""); + op_node->set_doc_string(this->get_name()); - auto* node = graph.add_node(); - *node = nodes.back(); - node->set_name(this->get_name()); - node->set_domain(""); - node->set_doc_string(this->get_name()); - if(const_node->op_type() == "PostConstant") + //binary operators + if(nodes.size() == 1) { - node->add_input(parent->get_name() + "_0"); - node->add_input(const_node->output(0)); - const_node->set_op_type("Constant"); + for(auto* parent : parents) + { + size_t idx = parent->find_child_layer_index(*this); + op_node->add_input(parent->get_name() + "_" + std::to_string(idx)); + } } - else if(const_node->op_type() == "PreConstant") + // Binary w/ constant operators + else if(nodes.size() == 2 || nodes.size() == 3) { - node->add_input(const_node->output(0)); - node->add_input(parent->get_name() + "_0"); + auto* const_node = graph.add_node(); + *const_node = nodes.back(); + if(const_node->op_type() == "PostConstant") + { + op_node->add_input(parents[0]->get_name() + "_0"); + op_node->add_input(const_node->output(0)); + } + else if(const_node->op_type() == "PreConstant") + { + op_node->add_input(const_node->output(0)); + op_node->add_input(parents[0]->get_name() + "_0"); + } + else + LBANN_ERROR("Unknown onnx op type for constant."); + const_node->set_op_type("Constant"); } else - LBANN_ERROR("Unknown onnx op type for constant."); + LBANN_ERROR("Expected 1-3 ONNX nodes for binary operation, received ", nodes.size()); // Not equal operator if(nodes.size() == 3) { - node->add_output("EqualOperator"); + op_node->add_output("EqualOperator"); auto* not_node = graph.add_node(); - not_node->add_input(node->output(0)); - not_node->add_output(this->get_child_layers()[0]->get_name() + "_0"); + not_node->add_input(op_node->output(0)); not_node->set_name("Not operator"); not_node->set_op_type("Not"); not_node->set_domain(""); not_node->set_doc_string("Not node for not equal operation."); + op_node = not_node; } - else if(nodes.size() == 2) - { - node->add_output(this->get_child_layers()[0]->get_name() + "_0"); + + for (auto const* child : this->get_child_layers()) { + auto idx = this->find_child_layer_index(*child); + op_node->add_output(this->get_name() + "_" + std::to_string(idx)); } - else - LBANN_ERROR("Expected two or three nodes for binary constant operation, received ", nodes.size()); } #endif // LBANN_HAS_ONNX From 41766ae0cf89da427f604d6acf4132035e2214d6 Mon Sep 17 00:00:00 2001 From: "Thomas R. Benson" Date: Wed, 18 May 2022 10:32:23 -0400 Subject: [PATCH 4/5] Clang-format --- include/lbann/layers/operator_layer.hpp | 2 +- .../operators/math/binary_with_constant.hpp | 47 +++++++++---------- include/lbann/operators/operator.hpp | 3 +- src/layers/operator_layer.cpp | 26 ++++------ 4 files changed, 36 insertions(+), 42 deletions(-) diff --git a/include/lbann/layers/operator_layer.hpp b/include/lbann/layers/operator_layer.hpp index 7daae25fe2e..8929f4520d8 100644 --- a/include/lbann/layers/operator_layer.hpp +++ b/include/lbann/layers/operator_layer.hpp @@ -83,7 +83,7 @@ class OperatorLayer final : public data_type_layer #ifdef LBANN_HAS_ONNX void fill_onnx_node(onnx::GraphProto& graph) const override; -#endif //LBANN_HAS_ONNX +#endif // LBANN_HAS_ONNX void fp_compute() final; void bp_compute() final; diff --git a/include/lbann/operators/math/binary_with_constant.hpp b/include/lbann/operators/math/binary_with_constant.hpp index c4312bf6575..9499ccfe024 100644 --- a/include/lbann/operators/math/binary_with_constant.hpp +++ b/include/lbann/operators/math/binary_with_constant.hpp @@ -177,8 +177,8 @@ inline onnx::NodeProto get_constant_node(float val) } template -std::vector get_onnx_nodes_impl( - AddConstantOperator const op) +std::vector +get_onnx_nodes_impl(AddConstantOperator const op) { std::vector nodes(2UL); nodes.front().set_op_type("Add"); @@ -188,8 +188,7 @@ std::vector get_onnx_nodes_impl( } template -std::vector get_onnx_nodes_impl( - ScaleOperator const op) +std::vector get_onnx_nodes_impl(ScaleOperator const op) { std::vector nodes(2UL); nodes.front().set_op_type("Mul"); @@ -199,8 +198,8 @@ std::vector get_onnx_nodes_impl( } template -std::vector get_onnx_nodes_impl( - SubtractConstantOperator const op) +std::vector +get_onnx_nodes_impl(SubtractConstantOperator const op) { std::vector nodes(2UL); nodes.front().set_op_type("Sub"); @@ -210,8 +209,8 @@ std::vector get_onnx_nodes_impl( } template -std::vector get_onnx_nodes_impl( - ConstantSubtractOperator const op) +std::vector +get_onnx_nodes_impl(ConstantSubtractOperator const op) { std::vector nodes(2UL); nodes.front().set_op_type("Sub"); @@ -221,8 +220,8 @@ std::vector get_onnx_nodes_impl( } template -std::vector get_onnx_nodes_impl( - MaxConstantOperator const op) +std::vector +get_onnx_nodes_impl(MaxConstantOperator const op) { std::vector nodes(2UL); nodes.front().set_op_type("Max"); @@ -232,8 +231,8 @@ std::vector get_onnx_nodes_impl( } template -std::vector get_onnx_nodes_impl( - MinConstantOperator const op) +std::vector +get_onnx_nodes_impl(MinConstantOperator const op) { std::vector nodes(2UL); nodes.front().set_op_type("Min"); @@ -243,8 +242,8 @@ std::vector get_onnx_nodes_impl( } template -std::vector get_onnx_nodes_impl( - EqualConstantOperator const op) +std::vector +get_onnx_nodes_impl(EqualConstantOperator const op) { std::vector nodes(2UL); nodes.front().set_op_type("Equal"); @@ -254,8 +253,8 @@ std::vector get_onnx_nodes_impl( } template -std::vector get_onnx_nodes_impl( - NotEqualConstantOperator const op) +std::vector +get_onnx_nodes_impl(NotEqualConstantOperator const op) { std::vector nodes(3UL); nodes.front().set_op_type("Equal"); @@ -266,8 +265,8 @@ std::vector get_onnx_nodes_impl( } template -std::vector get_onnx_nodes_impl( - LessConstantOperator const op) +std::vector +get_onnx_nodes_impl(LessConstantOperator const op) { std::vector nodes(2UL); nodes.front().set_op_type("Less"); @@ -277,8 +276,8 @@ std::vector get_onnx_nodes_impl( } template -std::vector get_onnx_nodes_impl( - LessEqualConstantOperator const op) +std::vector +get_onnx_nodes_impl(LessEqualConstantOperator const op) { std::vector nodes(2UL); nodes.front().set_op_type("LessOrEqual"); @@ -288,8 +287,8 @@ std::vector get_onnx_nodes_impl( } template -std::vector get_onnx_nodes_impl( - GreaterConstantOperator const op) +std::vector +get_onnx_nodes_impl(GreaterConstantOperator const op) { std::vector nodes(2UL); nodes.front().set_op_type("Greater"); @@ -299,8 +298,8 @@ std::vector get_onnx_nodes_impl( } template -std::vector get_onnx_nodes_impl( - GreaterEqualConstantOperator const op) +std::vector +get_onnx_nodes_impl(GreaterEqualConstantOperator const op) { std::vector nodes(2UL); nodes.front().set_op_type("GreaterOrEqual"); diff --git a/include/lbann/operators/operator.hpp b/include/lbann/operators/operator.hpp index e9a092a19ba..d1e53fbf1cf 100644 --- a/include/lbann/operators/operator.hpp +++ b/include/lbann/operators/operator.hpp @@ -218,7 +218,8 @@ void Operator::serialize(ArchiveT& ar) #ifdef LBANN_HAS_ONNX template -std::vector Operator::get_onnx_nodes() const +std::vector +Operator::get_onnx_nodes() const { // The default assumption is that we don't know how to represent // this operator in ONNX terms yet. diff --git a/src/layers/operator_layer.cpp b/src/layers/operator_layer.cpp index 2d5b2355334..5d007590de4 100644 --- a/src/layers/operator_layer.cpp +++ b/src/layers/operator_layer.cpp @@ -59,8 +59,7 @@ void OperatorLayer::write_specific_proto( #ifdef LBANN_HAS_ONNX template -void OperatorLayer::fill_onnx_node( - onnx::GraphProto& graph) const +void OperatorLayer::fill_onnx_node(onnx::GraphProto& graph) const { const auto& parents = this->get_parent_layers(); auto nodes = m_ops.front()->get_onnx_nodes(); @@ -72,27 +71,22 @@ void OperatorLayer::fill_onnx_node( op_node->set_domain(""); op_node->set_doc_string(this->get_name()); - //binary operators - if(nodes.size() == 1) - { - for(auto* parent : parents) - { + // binary operators + if (nodes.size() == 1) { + for (auto* parent : parents) { size_t idx = parent->find_child_layer_index(*this); op_node->add_input(parent->get_name() + "_" + std::to_string(idx)); } } // Binary w/ constant operators - else if(nodes.size() == 2 || nodes.size() == 3) - { + else if (nodes.size() == 2 || nodes.size() == 3) { auto* const_node = graph.add_node(); *const_node = nodes.back(); - if(const_node->op_type() == "PostConstant") - { + if (const_node->op_type() == "PostConstant") { op_node->add_input(parents[0]->get_name() + "_0"); op_node->add_input(const_node->output(0)); } - else if(const_node->op_type() == "PreConstant") - { + else if (const_node->op_type() == "PreConstant") { op_node->add_input(const_node->output(0)); op_node->add_input(parents[0]->get_name() + "_0"); } @@ -102,11 +96,11 @@ void OperatorLayer::fill_onnx_node( const_node->set_op_type("Constant"); } else - LBANN_ERROR("Expected 1-3 ONNX nodes for binary operation, received ", nodes.size()); + LBANN_ERROR("Expected 1-3 ONNX nodes for binary operation, received ", + nodes.size()); // Not equal operator - if(nodes.size() == 3) - { + if (nodes.size() == 3) { op_node->add_output("EqualOperator"); auto* not_node = graph.add_node(); not_node->add_input(op_node->output(0)); From cb9b040f4036623b6cdc0899479c8725ee6bf4bc Mon Sep 17 00:00:00 2001 From: Katie Graham Date: Tue, 21 Mar 2023 14:25:27 -0700 Subject: [PATCH 5/5] Resolved errors introduced by include-what-you-use and various other PRs --- include/lbann/operators/declare_stateless_op.hpp | 2 +- include/lbann/operators/math/binary_with_constant.hpp | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/include/lbann/operators/declare_stateless_op.hpp b/include/lbann/operators/declare_stateless_op.hpp index dc5d42eeb26..10a0de0431a 100644 --- a/include/lbann/operators/declare_stateless_op.hpp +++ b/include/lbann/operators/declare_stateless_op.hpp @@ -140,7 +140,7 @@ } \ void do_fill_description(description&) const final \ {} \ - } + }; namespace lbann { diff --git a/include/lbann/operators/math/binary_with_constant.hpp b/include/lbann/operators/math/binary_with_constant.hpp index 9499ccfe024..e021d9e1438 100644 --- a/include/lbann/operators/math/binary_with_constant.hpp +++ b/include/lbann/operators/math/binary_with_constant.hpp @@ -33,7 +33,7 @@ #include "lbann/utils/cloneable.hpp" #ifdef LBANN_HAS_ONNX -#include +#include #endif // LBANN_HAS_ONNX @@ -163,6 +163,7 @@ LBANN_DECLARE_BINARY_WITH_CONSTANT_OPERATOR(GreaterEqualConstant, LBANN_DECLARE_BINARY_WITH_CONSTANT_OPERATOR(GreaterConstant, "greater than constant"); +#ifdef LBANN_HAS_ONNX inline onnx::NodeProto get_constant_node(float val) { onnx::NodeProto const_node; @@ -307,6 +308,7 @@ get_onnx_nodes_impl(GreaterEqualConstantOperator const op) nodes.back().set_op_type("PreConstant"); return nodes; } +#endif // LBANN_HAS_ONNX } // namespace lbann #endif // LBANN_INCLUDE_LBANN_OPERATORS_BINARY_WITH_CONSTANT_HPP_INCLUDED