Skip to content

Commit 25a743d

Browse files
committed
added binary operators to fill_onnx_node, refactored binary_with constant for consistency with location of operator in vector for regular binary operators
1 parent 24c3efe commit 25a743d

File tree

2 files changed

+75
-67
lines changed

2 files changed

+75
-67
lines changed

include/lbann/operators/math/binary_with_constant.hpp

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -181,9 +181,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
181181
AddConstantOperator<T, D> const op)
182182
{
183183
std::vector<onnx::NodeProto> nodes(2UL);
184-
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
185-
nodes.front().set_op_type("PostConstant");
186-
nodes.back().set_op_type("Add");
184+
nodes.front().set_op_type("Add");
185+
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
186+
nodes.back().set_op_type("PostConstant");
187187
return nodes;
188188
}
189189

@@ -192,9 +192,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
192192
ScaleOperator<T, D> const op)
193193
{
194194
std::vector<onnx::NodeProto> nodes(2UL);
195-
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
196-
nodes.front().set_op_type("PostConstant");
197-
nodes.back().set_op_type("Mul");
195+
nodes.front().set_op_type("Mul");
196+
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
197+
nodes.back().set_op_type("PostConstant");
198198
return nodes;
199199
}
200200

@@ -203,9 +203,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
203203
SubtractConstantOperator<T, D> const op)
204204
{
205205
std::vector<onnx::NodeProto> nodes(2UL);
206-
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
207-
nodes.front().set_op_type("PostConstant");
208-
nodes.back().set_op_type("Sub");
206+
nodes.front().set_op_type("Sub");
207+
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
208+
nodes.back().set_op_type("PostConstant");
209209
return nodes;
210210
}
211211

@@ -214,9 +214,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
214214
ConstantSubtractOperator<T, D> const op)
215215
{
216216
std::vector<onnx::NodeProto> nodes(2UL);
217-
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
218-
nodes.front().set_op_type("PreConstant");
219-
nodes.back().set_op_type("Sub");
217+
nodes.front().set_op_type("Sub");
218+
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
219+
nodes.back().set_op_type("PreConstant");
220220
return nodes;
221221
}
222222

@@ -225,9 +225,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
225225
MaxConstantOperator<T, D> const op)
226226
{
227227
std::vector<onnx::NodeProto> nodes(2UL);
228-
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
229-
nodes.front().set_op_type("PreConstant");
230-
nodes.back().set_op_type("Max");
228+
nodes.front().set_op_type("Max");
229+
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
230+
nodes.back().set_op_type("PreConstant");
231231
return nodes;
232232
}
233233

@@ -236,9 +236,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
236236
MinConstantOperator<T, D> const op)
237237
{
238238
std::vector<onnx::NodeProto> nodes(2UL);
239-
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
240-
nodes.front().set_op_type("PreConstant");
241-
nodes.back().set_op_type("Min");
239+
nodes.front().set_op_type("Min");
240+
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
241+
nodes.back().set_op_type("PreConstant");
242242
return nodes;
243243
}
244244

@@ -247,9 +247,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
247247
EqualConstantOperator<T, D> const op)
248248
{
249249
std::vector<onnx::NodeProto> nodes(2UL);
250-
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
251-
nodes.front().set_op_type("PreConstant");
252-
nodes.back().set_op_type("Equal");
250+
nodes.front().set_op_type("Equal");
251+
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
252+
nodes.back().set_op_type("PreConstant");
253253
return nodes;
254254
}
255255

@@ -258,10 +258,10 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
258258
NotEqualConstantOperator<T, D> const op)
259259
{
260260
std::vector<onnx::NodeProto> nodes(3UL);
261-
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
262-
nodes.front().set_op_type("PreConstant");
261+
nodes.front().set_op_type("Equal");
262+
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
263+
nodes.back().set_op_type("PreConstant");
263264
nodes.at(1).set_op_type("Not");
264-
nodes.back().set_op_type("Equal");
265265
return nodes;
266266
}
267267

@@ -270,9 +270,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
270270
LessConstantOperator<T, D> const op)
271271
{
272272
std::vector<onnx::NodeProto> nodes(2UL);
273-
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
274-
nodes.front().set_op_type("PostConstant");
275-
nodes.back().set_op_type("Less");
273+
nodes.front().set_op_type("Less");
274+
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
275+
nodes.back().set_op_type("PostConstant");
276276
return nodes;
277277
}
278278

@@ -281,9 +281,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
281281
LessEqualConstantOperator<T, D> const op)
282282
{
283283
std::vector<onnx::NodeProto> nodes(2UL);
284-
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
285-
nodes.front().set_op_type("PostConstant");
286-
nodes.back().set_op_type("LessOrEqual");
284+
nodes.front().set_op_type("LessOrEqual");
285+
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
286+
nodes.back().set_op_type("PostConstant");
287287
return nodes;
288288
}
289289

@@ -292,9 +292,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
292292
GreaterConstantOperator<T, D> const op)
293293
{
294294
std::vector<onnx::NodeProto> nodes(2UL);
295-
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
296-
nodes.front().set_op_type("PreConstant");
297-
nodes.back().set_op_type("Greater");
295+
nodes.front().set_op_type("Greater");
296+
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
297+
nodes.back().set_op_type("PreConstant");
298298
return nodes;
299299
}
300300

@@ -303,9 +303,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
303303
GreaterEqualConstantOperator<T, D> const op)
304304
{
305305
std::vector<onnx::NodeProto> nodes(2UL);
306-
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
307-
nodes.front().set_op_type("PreConstant");
308-
nodes.back().set_op_type("GreaterOrEqual");
306+
nodes.front().set_op_type("GreaterOrEqual");
307+
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
308+
nodes.back().set_op_type("PreConstant");
309309
return nodes;
310310
}
311311

src/layers/operator_layer.cpp

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -60,57 +60,65 @@ template <typename T, typename O, data_layout L, El::Device D>
6060
void OperatorLayer<T, O, L, D>::fill_onnx_node(
6161
onnx::GraphProto& graph) const
6262
{
63-
std::vector<onnx::NodeProto> nodes(2UL);
64-
nodes.front().add_attribute()->set_type(onnx::AttributeProto::FLOAT);
65-
nodes.front().add_attribute()->set_f(El::To<float>(5));
66-
nodes.front().set_op_type("PostConstant");
67-
nodes.back().set_op_type("Add");
63+
const auto& parents = this->get_parent_layers();
64+
auto nodes = m_ops.front()->get_onnx_nodes();
6865

69-
//OperatorPtr op;
70-
//auto nodes = op->get_onnx_nodes();
71-
const auto* parent = this->get_parent_layers()[0];
66+
auto* op_node = graph.add_node();
67+
*op_node = nodes.front();
7268

73-
auto* const_node = graph.add_node();
74-
*const_node = nodes.front();
69+
op_node->set_name(this->get_name());
70+
op_node->set_domain("");
71+
op_node->set_doc_string(this->get_name());
7572

76-
auto* node = graph.add_node();
77-
*node = nodes.back();
78-
node->set_name(this->get_name());
79-
node->set_domain("");
80-
node->set_doc_string(this->get_name());
81-
if(const_node->op_type() == "PostConstant")
73+
//binary operators
74+
if(nodes.size() == 1)
8275
{
83-
node->add_input(parent->get_name() + "_0");
84-
node->add_input(const_node->output(0));
85-
const_node->set_op_type("Constant");
76+
for(auto* parent : parents)
77+
{
78+
size_t idx = parent->find_child_layer_index(*this);
79+
op_node->add_input(parent->get_name() + "_" + std::to_string(idx));
80+
}
8681
}
87-
else if(const_node->op_type() == "PreConstant")
82+
// Binary w/ constant operators
83+
else if(nodes.size() == 2 || nodes.size() == 3)
8884
{
89-
node->add_input(const_node->output(0));
90-
node->add_input(parent->get_name() + "_0");
85+
auto* const_node = graph.add_node();
86+
*const_node = nodes.back();
87+
if(const_node->op_type() == "PostConstant")
88+
{
89+
op_node->add_input(parents[0]->get_name() + "_0");
90+
op_node->add_input(const_node->output(0));
91+
}
92+
else if(const_node->op_type() == "PreConstant")
93+
{
94+
op_node->add_input(const_node->output(0));
95+
op_node->add_input(parents[0]->get_name() + "_0");
96+
}
97+
else
98+
LBANN_ERROR("Unknown onnx op type for constant.");
99+
91100
const_node->set_op_type("Constant");
92101
}
93102
else
94-
LBANN_ERROR("Unknown onnx op type for constant.");
103+
LBANN_ERROR("Expected 1-3 ONNX nodes for binary operation, received ", nodes.size());
95104

96105
// Not equal operator
97106
if(nodes.size() == 3)
98107
{
99-
node->add_output("EqualOperator");
108+
op_node->add_output("EqualOperator");
100109
auto* not_node = graph.add_node();
101-
not_node->add_input(node->output(0));
102-
not_node->add_output(this->get_child_layers()[0]->get_name() + "_0");
110+
not_node->add_input(op_node->output(0));
103111
not_node->set_name("Not operator");
104112
not_node->set_op_type("Not");
105113
not_node->set_domain("");
106114
not_node->set_doc_string("Not node for not equal operation.");
115+
op_node = not_node;
107116
}
108-
else if(nodes.size() == 2)
109-
{
110-
node->add_output(this->get_child_layers()[0]->get_name() + "_0");
117+
118+
for (auto const* child : this->get_child_layers()) {
119+
auto idx = this->find_child_layer_index(*child);
120+
op_node->add_output(this->get_name() + "_" + std::to_string(idx));
111121
}
112-
else
113-
LBANN_ERROR("Expected two or three nodes for binary constant operation, received ", nodes.size());
114122
}
115123
#endif // LBANN_HAS_ONNX
116124

0 commit comments

Comments
 (0)