@@ -1108,6 +1108,46 @@ TEST(test_utils_pattern_matcher, OptionalWithLargerPort) {
1108
1108
ASSERT_EQ (fusion_ops.size (), 3U );
1109
1109
}
1110
1110
1111
+ TEST (test_utils_pattern_matcher, CommutativeSelect) {
1112
+ auto graphp = std::make_shared<pb_graph_t >();
1113
+ // Pattern that captures
1114
+ // MatMul Add
1115
+ // \ /
1116
+ // Select
1117
+ auto pmatmul = graphp->append_op (MatMul);
1118
+ auto aadd = graphp->append_op (Add);
1119
+ auto pselect = graphp->append_op (
1120
+ Select, {in_edge (1 , pmatmul, 0 ), in_edge (2 , aadd, 0 )});
1121
+ UNUSED (pselect );
1122
+
1123
+ graph_t agraph;
1124
+ op_t matmul {0 , MatMul, " matmul" };
1125
+ op_t add {1 , Add, " Add" };
1126
+ op_t select {2 , Select, " select" };
1127
+
1128
+ std::vector<logical_tensor_t > lt_vec = create_logical_tensors (8 );
1129
+ lt_vec[6 ].data_type = data_type::boolean;
1130
+ matmul.add_input (lt_vec[0 ]);
1131
+ matmul.add_input (lt_vec[1 ]);
1132
+ matmul.add_output (lt_vec[2 ]);
1133
+ add.add_input (lt_vec[3 ]);
1134
+ add.add_input (lt_vec[4 ]);
1135
+ add.add_output (lt_vec[5 ]);
1136
+ select .add_input (lt_vec[6 ]);
1137
+ select .add_input (lt_vec[5 ]);
1138
+ select .add_input (lt_vec[2 ]);
1139
+ select .add_output (lt_vec[7 ]);
1140
+
1141
+ ASSERT_EQ (agraph.add_op (&matmul), status::success);
1142
+ ASSERT_EQ (agraph.add_op (&add), status::success);
1143
+ ASSERT_EQ (agraph.add_op (&select ), status::success);
1144
+ agraph.finalize ();
1145
+
1146
+ std::vector<op_t *> fusion_ops;
1147
+ EXPECT_TRUE (match_pattern (agraph.get_ops ()[0 ].get (), graphp, fusion_ops));
1148
+ ASSERT_EQ (fusion_ops.size (), 3U );
1149
+ }
1150
+
1111
1151
//
1112
1152
// ?: means optional
1113
1153
// ^: means repetition
0 commit comments