3232#include " lbann/operators/elementwise_operator.hpp"
3333#include " lbann/utils/cloneable.hpp"
3434
35+ #ifdef LBANN_HAS_ONNX
3536#include < onnx/onnx-ml.pb.h>
36- #include < operators.pb.h >
37+ #endif // LBANN_HAS_ONNX
3738
3839/* * @file
3940 *
@@ -164,6 +165,9 @@ LBANN_DECLARE_BINARY_WITH_CONSTANT_OPERATOR(GreaterConstant,
164165inline onnx::NodeProto get_constant_node (float val)
165166{
166167 onnx::NodeProto const_node;
168+ const_node.add_output (" const_val" );
169+ const_node.set_domain (" " );
170+ const_node.set_doc_string (" Const value for binary with constant operations" );
167171 auto * const_val = const_node.add_attribute ();
168172 const_val->set_name (" value_float" );
169173 const_val->set_type (onnx::AttributeProto::FLOAT);
@@ -184,79 +188,124 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
184188
185189template <typename T, El::Device D>
186190std::vector<onnx::NodeProto> get_onnx_nodes_impl (
187- ScaleOperator<T, D> const )
191+ ScaleOperator<T, D> const op )
188192{
189- return {};
193+ std::vector<onnx::NodeProto> nodes (2UL );
194+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
195+ nodes.front ().set_op_type (" PostConstant" );
196+ nodes.back ().set_op_type (" Mul" );
197+ return nodes;
190198}
191199
192200template <typename T, El::Device D>
193201std::vector<onnx::NodeProto> get_onnx_nodes_impl (
194- SubtractConstantOperator<T, D> const )
202+ SubtractConstantOperator<T, D> const op )
195203{
196- return {};
204+ std::vector<onnx::NodeProto> nodes (2UL );
205+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
206+ nodes.front ().set_op_type (" PostConstant" );
207+ nodes.back ().set_op_type (" Sub" );
208+ return nodes;
197209}
198210
199211template <typename T, El::Device D>
200212std::vector<onnx::NodeProto> get_onnx_nodes_impl (
201- ConstantSubtractOperator<T, D> const )
213+ ConstantSubtractOperator<T, D> const op )
202214{
203- return {};
215+ std::vector<onnx::NodeProto> nodes (2UL );
216+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
217+ nodes.front ().set_op_type (" PreConstant" );
218+ nodes.back ().set_op_type (" Sub" );
219+ return nodes;
204220}
205221
206222template <typename T, El::Device D>
207223std::vector<onnx::NodeProto> get_onnx_nodes_impl (
208- MaxConstantOperator<T, D> const )
224+ MaxConstantOperator<T, D> const op )
209225{
210- return {};
226+ std::vector<onnx::NodeProto> nodes (2UL );
227+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
228+ nodes.front ().set_op_type (" PreConstant" );
229+ nodes.back ().set_op_type (" Max" );
230+ return nodes;
211231}
212232
213233template <typename T, El::Device D>
214234std::vector<onnx::NodeProto> get_onnx_nodes_impl (
215- MinConstantOperator<T, D> const )
235+ MinConstantOperator<T, D> const op )
216236{
217- return {};
237+ std::vector<onnx::NodeProto> nodes (2UL );
238+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
239+ nodes.front ().set_op_type (" PreConstant" );
240+ nodes.back ().set_op_type (" Min" );
241+ return nodes;
218242}
219243
220244template <typename T, El::Device D>
221245std::vector<onnx::NodeProto> get_onnx_nodes_impl (
222- EqualConstantOperator<T, D> const )
246+ EqualConstantOperator<T, D> const op )
223247{
224- return {};
248+ std::vector<onnx::NodeProto> nodes (2UL );
249+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
250+ nodes.front ().set_op_type (" PreConstant" );
251+ nodes.back ().set_op_type (" Equal" );
252+ return nodes;
225253}
226254
227255template <typename T, El::Device D>
228256std::vector<onnx::NodeProto> get_onnx_nodes_impl (
229- NotEqualConstantOperator<T, D> const )
257+ NotEqualConstantOperator<T, D> const op )
230258{
231- return {};
259+ std::vector<onnx::NodeProto> nodes (3UL );
260+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
261+ nodes.front ().set_op_type (" PreConstant" );
262+ nodes.at (1 ).set_op_type (" Not" );
263+ nodes.back ().set_op_type (" Equal" );
264+ return nodes;
232265}
233266
234267template <typename T, El::Device D>
235268std::vector<onnx::NodeProto> get_onnx_nodes_impl (
236- LessConstantOperator<T, D> const )
269+ LessConstantOperator<T, D> const op )
237270{
238- return {};
271+ std::vector<onnx::NodeProto> nodes (2UL );
272+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
273+ nodes.front ().set_op_type (" PostConstant" );
274+ nodes.back ().set_op_type (" Less" );
275+ return nodes;
239276}
240277
241278template <typename T, El::Device D>
242279std::vector<onnx::NodeProto> get_onnx_nodes_impl (
243- LessEqualConstantOperator<T, D> const )
280+ LessEqualConstantOperator<T, D> const op )
244281{
245- return {};
282+ std::vector<onnx::NodeProto> nodes (2UL );
283+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
284+ nodes.front ().set_op_type (" PostConstant" );
285+ nodes.back ().set_op_type (" LessOrEqual" );
286+ return nodes;
246287}
247288
248289template <typename T, El::Device D>
249290std::vector<onnx::NodeProto> get_onnx_nodes_impl (
250- GreaterConstantOperator<T, D> const )
291+ GreaterConstantOperator<T, D> const op )
251292{
252- return {};
293+ std::vector<onnx::NodeProto> nodes (2UL );
294+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
295+ nodes.front ().set_op_type (" PreConstant" );
296+ nodes.back ().set_op_type (" Greater" );
297+ return nodes;
253298}
254299
255300template <typename T, El::Device D>
256301std::vector<onnx::NodeProto> get_onnx_nodes_impl (
257- GreaterEqualConstantOperator<T, D> const )
302+ GreaterEqualConstantOperator<T, D> const op )
258303{
259- return {};
304+ std::vector<onnx::NodeProto> nodes (2UL );
305+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
306+ nodes.front ().set_op_type (" PreConstant" );
307+ nodes.back ().set_op_type (" GreaterOrEqual" );
308+ return nodes;
260309}
261310
262311} // namespace lbann
0 commit comments