Skip to content

Commit 7e9349a

Browse files
committed
gtests: graph: unit: add gtests for verifying commutative select
1 parent 57ca202 commit 7e9349a

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

tests/gtests/graph/unit/utils/test_pattern_matcher_cpu.cpp

+40
Original file line numberDiff line numberDiff line change
@@ -1108,6 +1108,46 @@ TEST(test_utils_pattern_matcher, OptionalWithLargerPort) {
11081108
ASSERT_EQ(fusion_ops.size(), 3U);
11091109
}
11101110

1111+
TEST(test_utils_pattern_matcher, CommutativeSelect) {
1112+
auto graphp = std::make_shared<pb_graph_t>();
1113+
// Pattern that captures
1114+
// MatMul Add
1115+
// \ /
1116+
// Select
1117+
auto pmatmul = graphp->append_op(MatMul);
1118+
auto aadd = graphp->append_op(Add);
1119+
auto pselect = graphp->append_op(
1120+
Select, {in_edge(1, pmatmul, 0), in_edge(2, aadd, 0)});
1121+
UNUSED(pselect);
1122+
1123+
graph_t agraph;
1124+
op_t matmul {0, MatMul, "matmul"};
1125+
op_t add {1, Add, "Add"};
1126+
op_t select {2, Select, "select"};
1127+
1128+
std::vector<logical_tensor_t> lt_vec = create_logical_tensors(8);
1129+
lt_vec[6].data_type = data_type::boolean;
1130+
matmul.add_input(lt_vec[0]);
1131+
matmul.add_input(lt_vec[1]);
1132+
matmul.add_output(lt_vec[2]);
1133+
add.add_input(lt_vec[3]);
1134+
add.add_input(lt_vec[4]);
1135+
add.add_output(lt_vec[5]);
1136+
select.add_input(lt_vec[6]);
1137+
select.add_input(lt_vec[5]);
1138+
select.add_input(lt_vec[2]);
1139+
select.add_output(lt_vec[7]);
1140+
1141+
ASSERT_EQ(agraph.add_op(&matmul), status::success);
1142+
ASSERT_EQ(agraph.add_op(&add), status::success);
1143+
ASSERT_EQ(agraph.add_op(&select), status::success);
1144+
agraph.finalize();
1145+
1146+
std::vector<op_t *> fusion_ops;
1147+
EXPECT_TRUE(match_pattern(agraph.get_ops()[0].get(), graphp, fusion_ops));
1148+
ASSERT_EQ(fusion_ops.size(), 3U);
1149+
}
1150+
11111151
//
11121152
// ?: means optional
11131153
// ^: means repetition

0 commit comments

Comments
 (0)