3232#include " lbann/operators/elementwise_operator.hpp"
3333#include " lbann/utils/cloneable.hpp"
3434
35+ #ifdef LBANN_HAS_ONNX
3536#include < onnx/onnx-ml.pb.h>
36- #include < operators.pb.h >
37+ #endif // LBANN_HAS_ONNX
3738
3839
3940/* * @file
@@ -165,6 +166,9 @@ LBANN_DECLARE_BINARY_WITH_CONSTANT_OPERATOR(GreaterConstant,
165166inline onnx::NodeProto get_constant_node (float val)
166167{
167168 onnx::NodeProto const_node;
169+ const_node.add_output (" const_val" );
170+ const_node.set_domain (" " );
171+ const_node.set_doc_string (" Const value for binary with constant operations" );
168172 auto * const_val = const_node.add_attribute ();
169173 const_val->set_name (" value_float" );
170174 const_val->set_type (onnx::AttributeProto::FLOAT);
@@ -185,79 +189,124 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
185189
186190template <typename T, El::Device D>
187191std::vector<onnx::NodeProto> get_onnx_nodes_impl (
188- ScaleOperator<T, D> const )
192+ ScaleOperator<T, D> const op )
189193{
190- return {};
194+ std::vector<onnx::NodeProto> nodes (2UL );
195+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
196+ nodes.front ().set_op_type (" PostConstant" );
197+ nodes.back ().set_op_type (" Mul" );
198+ return nodes;
191199}
192200
193201template <typename T, El::Device D>
194202std::vector<onnx::NodeProto> get_onnx_nodes_impl (
195- SubtractConstantOperator<T, D> const )
203+ SubtractConstantOperator<T, D> const op )
196204{
197- return {};
205+ std::vector<onnx::NodeProto> nodes (2UL );
206+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
207+ nodes.front ().set_op_type (" PostConstant" );
208+ nodes.back ().set_op_type (" Sub" );
209+ return nodes;
198210}
199211
200212template <typename T, El::Device D>
201213std::vector<onnx::NodeProto> get_onnx_nodes_impl (
202- ConstantSubtractOperator<T, D> const )
214+ ConstantSubtractOperator<T, D> const op )
203215{
204- return {};
216+ std::vector<onnx::NodeProto> nodes (2UL );
217+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
218+ nodes.front ().set_op_type (" PreConstant" );
219+ nodes.back ().set_op_type (" Sub" );
220+ return nodes;
205221}
206222
207223template <typename T, El::Device D>
208224std::vector<onnx::NodeProto> get_onnx_nodes_impl (
209- MaxConstantOperator<T, D> const )
225+ MaxConstantOperator<T, D> const op )
210226{
211- return {};
227+ std::vector<onnx::NodeProto> nodes (2UL );
228+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
229+ nodes.front ().set_op_type (" PreConstant" );
230+ nodes.back ().set_op_type (" Max" );
231+ return nodes;
212232}
213233
214234template <typename T, El::Device D>
215235std::vector<onnx::NodeProto> get_onnx_nodes_impl (
216- MinConstantOperator<T, D> const )
236+ MinConstantOperator<T, D> const op )
217237{
218- return {};
238+ std::vector<onnx::NodeProto> nodes (2UL );
239+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
240+ nodes.front ().set_op_type (" PreConstant" );
241+ nodes.back ().set_op_type (" Min" );
242+ return nodes;
219243}
220244
221245template <typename T, El::Device D>
222246std::vector<onnx::NodeProto> get_onnx_nodes_impl (
223- EqualConstantOperator<T, D> const )
247+ EqualConstantOperator<T, D> const op )
224248{
225- return {};
249+ std::vector<onnx::NodeProto> nodes (2UL );
250+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
251+ nodes.front ().set_op_type (" PreConstant" );
252+ nodes.back ().set_op_type (" Equal" );
253+ return nodes;
226254}
227255
228256template <typename T, El::Device D>
229257std::vector<onnx::NodeProto> get_onnx_nodes_impl (
230- NotEqualConstantOperator<T, D> const )
258+ NotEqualConstantOperator<T, D> const op )
231259{
232- return {};
260+ std::vector<onnx::NodeProto> nodes (3UL );
261+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
262+ nodes.front ().set_op_type (" PreConstant" );
263+ nodes.at (1 ).set_op_type (" Not" );
264+ nodes.back ().set_op_type (" Equal" );
265+ return nodes;
233266}
234267
235268template <typename T, El::Device D>
236269std::vector<onnx::NodeProto> get_onnx_nodes_impl (
237- LessConstantOperator<T, D> const )
270+ LessConstantOperator<T, D> const op )
238271{
239- return {};
272+ std::vector<onnx::NodeProto> nodes (2UL );
273+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
274+ nodes.front ().set_op_type (" PostConstant" );
275+ nodes.back ().set_op_type (" Less" );
276+ return nodes;
240277}
241278
242279template <typename T, El::Device D>
243280std::vector<onnx::NodeProto> get_onnx_nodes_impl (
244- LessEqualConstantOperator<T, D> const )
281+ LessEqualConstantOperator<T, D> const op )
245282{
246- return {};
283+ std::vector<onnx::NodeProto> nodes (2UL );
284+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
285+ nodes.front ().set_op_type (" PostConstant" );
286+ nodes.back ().set_op_type (" LessOrEqual" );
287+ return nodes;
247288}
248289
249290template <typename T, El::Device D>
250291std::vector<onnx::NodeProto> get_onnx_nodes_impl (
251- GreaterConstantOperator<T, D> const )
292+ GreaterConstantOperator<T, D> const op )
252293{
253- return {};
294+ std::vector<onnx::NodeProto> nodes (2UL );
295+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
296+ nodes.front ().set_op_type (" PreConstant" );
297+ nodes.back ().set_op_type (" Greater" );
298+ return nodes;
254299}
255300
256301template <typename T, El::Device D>
257302std::vector<onnx::NodeProto> get_onnx_nodes_impl (
258- GreaterEqualConstantOperator<T, D> const )
303+ GreaterEqualConstantOperator<T, D> const op )
259304{
260- return {};
305+ std::vector<onnx::NodeProto> nodes (2UL );
306+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
307+ nodes.front ().set_op_type (" PreConstant" );
308+ nodes.back ().set_op_type (" GreaterOrEqual" );
309+ return nodes;
261310}
262311
263312} // namespace lbann
0 commit comments