@@ -70,10 +70,10 @@ namespace pm {
70
70
71
71
namespace {
72
72
// check if an op's inputs are commutative
73
- bool has_commutative_inputs (op_t *op) {
73
+ bool is_commutative_input (op_t *op, const std::set< size_t > &inputs ) {
74
74
const op_schema_t *opm
75
75
= op_schema_registry_t::get_op_schema (op->get_kind ());
76
- return opm->get_commutative_inputs ();
76
+ return opm->get_commutative_inputs (inputs );
77
77
}
78
78
79
79
// fill local context in map when optional exists
@@ -200,12 +200,16 @@ bool node_inputs_matcher_t::match_commutative_inputs() {
200
200
201
201
for (size_t node_input_offset = 0 ; node_input_offset < node_inputs_.size ();
202
202
++node_input_offset) {
203
+ if (!is_commutative_input (
204
+ get_op (), {node_inputs_[node_input_offset].first }))
205
+ continue ;
203
206
if (verified_node_input_ports.find (
204
207
node_inputs_[node_input_offset].first )
205
208
!= verified_node_input_ports.end ())
206
209
continue ;
207
210
for (size_t op_input_offset = 0 ; op_input_offset < op_->num_inputs ();
208
211
++op_input_offset) {
212
+ if (!is_commutative_input (get_op (), {op_input_offset})) continue ;
209
213
if (verified_op_input_ports.find (op_input_offset)
210
214
== verified_op_input_ports.end ()
211
215
&& match_input_by_offset (
@@ -252,7 +256,8 @@ bool match_node_inputs(const binding_t &b, match_context_t *ctx,
252
256
if (node_inputs_matcher.get_node ()->get_inputs ().size ()
253
257
== VARIADIC_INPUT_NUM) {
254
258
matching_status = node_inputs_matcher.match_variadic_inputs ();
255
- } else if (!has_commutative_inputs (node_inputs_matcher.get_op ())) {
259
+ } else if (!is_commutative_input (node_inputs_matcher.get_op (),
260
+ {b.bind_op_port , b.bind_port })) {
256
261
matching_status = node_inputs_matcher.match_non_commutative_inputs ();
257
262
} else {
258
263
matching_status = node_inputs_matcher.match_commutative_inputs ();
@@ -552,7 +557,9 @@ bool match_node(const binding_t &b, match_context_t *ctx,
552
557
__FILE__, __LINE__);
553
558
return false ;
554
559
}
555
- if (!has_commutative_inputs (b.bind_op ) && b.bind_op_port != b.bind_port ) {
560
+ if (b.bind_op_port != b.bind_port
561
+ && !is_commutative_input (
562
+ b.bind_op , {b.bind_op_port , b.bind_port })) {
556
563
DEBUG (DEBUGINFO_PM,
557
564
" matching op & node: %s (%s) <=> %s, matching "
558
565
" failed \n " ,
0 commit comments