Skip to content

Commit 24c3efe

Browse files
committed
added fill_onnx_node() to operator layer
1 parent a90ee88 commit 24c3efe

File tree

3 files changed

+135
-23
lines changed

3 files changed

+135
-23
lines changed

include/lbann/layers/operator_layer.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ class OperatorLayer final : public data_type_layer<InputT, OutputT>
8181
data_layout get_data_layout() const final;
8282
El::Device get_device_allocation() const final;
8383

84+
#ifdef LBANN_HAS_ONNX
85+
void fill_onnx_node(onnx::GraphProto& graph) const override;
86+
#endif //LBANN_HAS_ONNX
87+
8488
void fp_compute() final;
8589
void bp_compute() final;
8690

include/lbann/operators/math/binary_with_constant.hpp

Lines changed: 72 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,9 @@
3232
#include "lbann/operators/elementwise_operator.hpp"
3333
#include "lbann/utils/cloneable.hpp"
3434

35+
#ifdef LBANN_HAS_ONNX
3536
#include <onnx/onnx-ml.pb.h>
36-
#include <operators.pb.h>
37+
#endif // LBANN_HAS_ONNX
3738

3839

3940
/** @file
@@ -165,6 +166,9 @@ LBANN_DECLARE_BINARY_WITH_CONSTANT_OPERATOR(GreaterConstant,
165166
inline onnx::NodeProto get_constant_node(float val)
166167
{
167168
onnx::NodeProto const_node;
169+
const_node.add_output("const_val");
170+
const_node.set_domain("");
171+
const_node.set_doc_string("Const value for binary with constant operations");
168172
auto* const_val = const_node.add_attribute();
169173
const_val->set_name("value_float");
170174
const_val->set_type(onnx::AttributeProto::FLOAT);
@@ -185,79 +189,124 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
185189

186190
template <typename T, El::Device D>
187191
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
188-
ScaleOperator<T, D> const)
192+
ScaleOperator<T, D> const op)
189193
{
190-
return {};
194+
std::vector<onnx::NodeProto> nodes(2UL);
195+
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
196+
nodes.front().set_op_type("PostConstant");
197+
nodes.back().set_op_type("Mul");
198+
return nodes;
191199
}
192200

193201
template <typename T, El::Device D>
194202
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
195-
SubtractConstantOperator<T, D> const)
203+
SubtractConstantOperator<T, D> const op)
196204
{
197-
return {};
205+
std::vector<onnx::NodeProto> nodes(2UL);
206+
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
207+
nodes.front().set_op_type("PostConstant");
208+
nodes.back().set_op_type("Sub");
209+
return nodes;
198210
}
199211

200212
template <typename T, El::Device D>
201213
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
202-
ConstantSubtractOperator<T, D> const)
214+
ConstantSubtractOperator<T, D> const op)
203215
{
204-
return {};
216+
std::vector<onnx::NodeProto> nodes(2UL);
217+
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
218+
nodes.front().set_op_type("PreConstant");
219+
nodes.back().set_op_type("Sub");
220+
return nodes;
205221
}
206222

207223
template <typename T, El::Device D>
208224
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
209-
MaxConstantOperator<T, D> const)
225+
MaxConstantOperator<T, D> const op)
210226
{
211-
return {};
227+
std::vector<onnx::NodeProto> nodes(2UL);
228+
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
229+
nodes.front().set_op_type("PreConstant");
230+
nodes.back().set_op_type("Max");
231+
return nodes;
212232
}
213233

214234
template <typename T, El::Device D>
215235
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
216-
MinConstantOperator<T, D> const)
236+
MinConstantOperator<T, D> const op)
217237
{
218-
return {};
238+
std::vector<onnx::NodeProto> nodes(2UL);
239+
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
240+
nodes.front().set_op_type("PreConstant");
241+
nodes.back().set_op_type("Min");
242+
return nodes;
219243
}
220244

221245
template <typename T, El::Device D>
222246
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
223-
EqualConstantOperator<T, D> const)
247+
EqualConstantOperator<T, D> const op)
224248
{
225-
return {};
249+
std::vector<onnx::NodeProto> nodes(2UL);
250+
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
251+
nodes.front().set_op_type("PreConstant");
252+
nodes.back().set_op_type("Equal");
253+
return nodes;
226254
}
227255

228256
template <typename T, El::Device D>
229257
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
230-
NotEqualConstantOperator<T, D> const)
258+
NotEqualConstantOperator<T, D> const op)
231259
{
232-
return {};
260+
std::vector<onnx::NodeProto> nodes(3UL);
261+
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
262+
nodes.front().set_op_type("PreConstant");
263+
nodes.at(1).set_op_type("Not");
264+
nodes.back().set_op_type("Equal");
265+
return nodes;
233266
}
234267

235268
template <typename T, El::Device D>
236269
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
237-
LessConstantOperator<T, D> const)
270+
LessConstantOperator<T, D> const op)
238271
{
239-
return {};
272+
std::vector<onnx::NodeProto> nodes(2UL);
273+
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
274+
nodes.front().set_op_type("PostConstant");
275+
nodes.back().set_op_type("Less");
276+
return nodes;
240277
}
241278

242279
template <typename T, El::Device D>
243280
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
244-
LessEqualConstantOperator<T, D> const)
281+
LessEqualConstantOperator<T, D> const op)
245282
{
246-
return {};
283+
std::vector<onnx::NodeProto> nodes(2UL);
284+
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
285+
nodes.front().set_op_type("PostConstant");
286+
nodes.back().set_op_type("LessOrEqual");
287+
return nodes;
247288
}
248289

249290
template <typename T, El::Device D>
250291
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
251-
GreaterConstantOperator<T, D> const)
292+
GreaterConstantOperator<T, D> const op)
252293
{
253-
return {};
294+
std::vector<onnx::NodeProto> nodes(2UL);
295+
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
296+
nodes.front().set_op_type("PreConstant");
297+
nodes.back().set_op_type("Greater");
298+
return nodes;
254299
}
255300

256301
template <typename T, El::Device D>
257302
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
258-
GreaterEqualConstantOperator<T, D> const)
303+
GreaterEqualConstantOperator<T, D> const op)
259304
{
260-
return {};
305+
std::vector<onnx::NodeProto> nodes(2UL);
306+
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
307+
nodes.front().set_op_type("PreConstant");
308+
nodes.back().set_op_type("GreaterOrEqual");
309+
return nodes;
261310
}
262311

263312
} // namespace lbann

src/layers/operator_layer.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,63 @@ void OperatorLayer<T, O, L, D>::write_specific_proto(lbann_data::Layer& proto) c
5555
op->set_device_allocation(proto::ProtoDevice<D>);
5656
}
5757

58+
#ifdef LBANN_HAS_ONNX
59+
template <typename T, typename O, data_layout L, El::Device D>
60+
void OperatorLayer<T, O, L, D>::fill_onnx_node(
61+
onnx::GraphProto& graph) const
62+
{
63+
std::vector<onnx::NodeProto> nodes(2UL);
64+
nodes.front().add_attribute()->set_type(onnx::AttributeProto::FLOAT);
65+
nodes.front().add_attribute()->set_f(El::To<float>(5));
66+
nodes.front().set_op_type("PostConstant");
67+
nodes.back().set_op_type("Add");
68+
69+
//OperatorPtr op;
70+
//auto nodes = op->get_onnx_nodes();
71+
const auto* parent = this->get_parent_layers()[0];
72+
73+
auto* const_node = graph.add_node();
74+
*const_node = nodes.front();
75+
76+
auto* node = graph.add_node();
77+
*node = nodes.back();
78+
node->set_name(this->get_name());
79+
node->set_domain("");
80+
node->set_doc_string(this->get_name());
81+
if(const_node->op_type() == "PostConstant")
82+
{
83+
node->add_input(parent->get_name() + "_0");
84+
node->add_input(const_node->output(0));
85+
const_node->set_op_type("Constant");
86+
}
87+
else if(const_node->op_type() == "PreConstant")
88+
{
89+
node->add_input(const_node->output(0));
90+
node->add_input(parent->get_name() + "_0");
91+
const_node->set_op_type("Constant");
92+
}
93+
else
94+
LBANN_ERROR("Unknown onnx op type for constant.");
95+
96+
// Not equal operator
97+
if(nodes.size() == 3)
98+
{
99+
node->add_output("EqualOperator");
100+
auto* not_node = graph.add_node();
101+
not_node->add_input(node->output(0));
102+
not_node->add_output(this->get_child_layers()[0]->get_name() + "_0");
103+
not_node->set_name("Not operator");
104+
not_node->set_op_type("Not");
105+
not_node->set_domain("");
106+
not_node->set_doc_string("Not node for not equal operation.");
107+
}
108+
else if(nodes.size() == 2)
109+
{
110+
node->add_output(this->get_child_layers()[0]->get_name() + "_0");
111+
}
112+
else
113+
LBANN_ERROR("Expected two or three nodes for binary constant operation, received ", nodes.size());
114+
}
115+
#endif // LBANN_HAS_ONNX
116+
58117
} // namespace lbann

0 commit comments

Comments
 (0)