Skip to content

Commit 57ca202

Browse files
committed
graph: utils: pm: support non-binary commutative op
1 parent b579a2e commit 57ca202

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

src/graph/utils/pm/nested_matcher.cpp

+11-4
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,10 @@ namespace pm {
7070

7171
namespace {
7272
// check if an op's inputs are commutative
73-
bool has_commutative_inputs(op_t *op) {
73+
bool is_commutative_input(op_t *op, const std::set<size_t> &inputs) {
7474
const op_schema_t *opm
7575
= op_schema_registry_t::get_op_schema(op->get_kind());
76-
return opm->get_commutative_inputs();
76+
return opm->get_commutative_inputs(inputs);
7777
}
7878

7979
// fill local context in map when optional exists
@@ -200,12 +200,16 @@ bool node_inputs_matcher_t::match_commutative_inputs() {
200200

201201
for (size_t node_input_offset = 0; node_input_offset < node_inputs_.size();
202202
++node_input_offset) {
203+
if (!is_commutative_input(
204+
get_op(), {node_inputs_[node_input_offset].first}))
205+
continue;
203206
if (verified_node_input_ports.find(
204207
node_inputs_[node_input_offset].first)
205208
!= verified_node_input_ports.end())
206209
continue;
207210
for (size_t op_input_offset = 0; op_input_offset < op_->num_inputs();
208211
++op_input_offset) {
212+
if (!is_commutative_input(get_op(), {op_input_offset})) continue;
209213
if (verified_op_input_ports.find(op_input_offset)
210214
== verified_op_input_ports.end()
211215
&& match_input_by_offset(
@@ -252,7 +256,8 @@ bool match_node_inputs(const binding_t &b, match_context_t *ctx,
252256
if (node_inputs_matcher.get_node()->get_inputs().size()
253257
== VARIADIC_INPUT_NUM) {
254258
matching_status = node_inputs_matcher.match_variadic_inputs();
255-
} else if (!has_commutative_inputs(node_inputs_matcher.get_op())) {
259+
} else if (!is_commutative_input(node_inputs_matcher.get_op(),
260+
{b.bind_op_port, b.bind_port})) {
256261
matching_status = node_inputs_matcher.match_non_commutative_inputs();
257262
} else {
258263
matching_status = node_inputs_matcher.match_commutative_inputs();
@@ -552,7 +557,9 @@ bool match_node(const binding_t &b, match_context_t *ctx,
552557
__FILE__, __LINE__);
553558
return false;
554559
}
555-
if (!has_commutative_inputs(b.bind_op) && b.bind_op_port != b.bind_port) {
560+
if (b.bind_op_port != b.bind_port
561+
&& !is_commutative_input(
562+
b.bind_op, {b.bind_op_port, b.bind_port})) {
556563
DEBUG(DEBUGINFO_PM,
557564
"matching op & node: %s (%s) <=> %s, matching "
558565
"failed \n",

0 commit comments

Comments
 (0)