Commit 7b923c8

add subgraph tests for un-fusing behavior
1 parent 9ba3634

File tree

  • src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/mlp_fusion.cpp

1 file changed: +40 −13 lines

src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/mlp_fusion.cpp

Lines changed: 40 additions & 13 deletions
@@ -6,13 +6,13 @@
 #include <vector>
 
 #include "common_test_utils/ov_tensor_utils.hpp"
-#include "openvino/runtime/exec_model_info.hpp"
-#include "shared_test_classes/base/ov_subgraph.hpp"
 #include "openvino/op/convert.hpp"
 #include "openvino/op/gelu.hpp"
 #include "openvino/op/matmul.hpp"
 #include "openvino/op/multiply.hpp"
 #include "openvino/op/swish.hpp"
+#include "openvino/runtime/exec_model_info.hpp"
+#include "shared_test_classes/base/ov_subgraph.hpp"
 
 namespace ov {
 namespace test {
@@ -23,6 +23,7 @@ struct LLMMLPFusionParams {
     size_t up_size;
     std::string act_type;
     bool use_dynamic_quant;
+    bool swap_inputs;  // true = swap inputs to prevent fusion, false = normal order for fusion
 };
 
 class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>, public ov::test::SubgraphBaseTest {
@@ -39,6 +40,7 @@ class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>,
         result << "up_size=" << obj.param.up_size << "_";
         result << "act_type=" << obj.param.act_type << "_";
         result << "use_dynamic_quant=" << obj.param.use_dynamic_quant << "_";
+        result << "swap_inputs=" << obj.param.swap_inputs << "_";
         result << obj.index;
         return result.str();
     }
@@ -70,7 +72,8 @@ class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>,
             in_data.start_from = 0;
             in_data.range = 1;
             in_data.resolution = 128;
-            auto tensor_scale_per_oc = ov::test::utils::create_and_fill_tensor(ov::element::f32, ov::Shape{OC, 1}, in_data);
+            auto tensor_scale_per_oc =
+                ov::test::utils::create_and_fill_tensor(ov::element::f32, ov::Shape{OC, 1}, in_data);
             auto scale_per_oc = std::make_shared<ov::op::v0::Constant>(tensor_scale_per_oc);
 
             auto weight_deq = std::make_shared<ov::op::v1::Multiply>(weight_const_f32, scale_per_oc);
@@ -85,7 +88,8 @@ class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>,
             return std::make_shared<ov::op::v0::Constant>(tensor);
         };
         if (param.use_dynamic_quant)
-            configuration.insert({ov::hint::dynamic_quantization_group_size.name(), std::numeric_limits<uint64_t>::max()});
+            configuration.insert(
+                {ov::hint::dynamic_quantization_group_size.name(), std::numeric_limits<uint64_t>::max()});
 
         auto gate_weight = create_const(param.up_size, param.down_size, 100);
         auto up_weight = create_const(param.up_size, param.down_size, 100);
@@ -101,13 +105,22 @@ class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>,
         if (param.act_type == "Gelu")
            gate_act = std::make_shared<ov::op::v7::Gelu>(gate_proj);
 
-        auto gate_up = std::make_shared<ov::op::v1::Multiply>(gate_act, up_proj);
+        // Control input order based on the swap_inputs parameter
+        std::shared_ptr<ov::op::v1::Multiply> gate_up;
+        if (param.swap_inputs) {
+            // Swapped order should prevent fusion
+            gate_up = std::make_shared<ov::op::v1::Multiply>(up_proj, gate_act);
+        } else {
+            // Normal order should allow fusion
+            gate_up = std::make_shared<ov::op::v1::Multiply>(gate_act, up_proj);
+        }
+
         auto output = std::make_shared<ov::op::v0::MatMul>(gate_up, down_weight, false, true);
 
         function = std::make_shared<ov::Model>(ov::OutputVector{output}, ov::ParameterVector{src});
     }
 
-    void check_results() {
+    void check_fusion_result() {
         auto exec_model = compiledModel.get_runtime_model();
 
         int fused_node_found = 0;
@@ -116,26 +129,40 @@ class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>,
             if (layer_type == "LLMMLP")
                 fused_node_found++;
         }
-        ASSERT_EQ(fused_node_found, 1);
+
+        auto& param = this->GetParam();
+        if (param.swap_inputs) {
+            // When inputs are swapped, fusion should NOT happen
+            ASSERT_EQ(fused_node_found, 0) << "Fusion should not occur with swapped inputs";
+        } else {
+            // Normal case, fusion should happen
+            ASSERT_EQ(fused_node_found, 1) << "Fusion should occur with correct input order";
+        }
     }
 };
 
 TEST_P(LLMMLPFusionTest, CompareWithRefs) {
     if (!ov::with_cpu_x86_avx512_core_amx_bf16())
         GTEST_SKIP();
     run();
-    check_results();
+    check_fusion_result();
 }
 
 namespace {
 
-static ov::test::InputShape ishape{ov::PartialShape{-1, -1, 4096 / 4}, {ov::Shape{1, 8, 4096 / 4}, ov::Shape{5, 37, 4096 / 4}}};
+static ov::test::InputShape ishape{ov::PartialShape{-1, -1, 4096 / 4},
+                                   {ov::Shape{1, 8, 4096 / 4}, ov::Shape{5, 37, 4096 / 4}}};
 
+// Test parameters combining both normal fusion and no-fusion cases
 const std::vector<LLMMLPFusionParams> mlp_params = {
-    {ishape, 4096 / 4, 11008 / 4, "Gelu", false},
-    {ishape, 4096 / 4, 11008 / 4, "Gelu", true},
-    {ishape, 4096 / 4, 11008 / 4, "Swish", false},
-    {ishape, 4096 / 4, 11008 / 4, "Swish", true},
+    // Normal cases - should fuse (swap_inputs = false)
+    {ishape, 4096 / 4, 11008 / 4, "Gelu", false, false},
+    {ishape, 4096 / 4, 11008 / 4, "Gelu", true, false},
+    {ishape, 4096 / 4, 11008 / 4, "Swish", false, false},
+    {ishape, 4096 / 4, 11008 / 4, "Swish", true, false},
+
+    // Port order issue cases - should NOT fuse (swap_inputs = true)
+    {ishape, 4096 / 4, 11008 / 4, "Gelu", false, true},
 };
 
 INSTANTIATE_TEST_SUITE_P(smoke_LLMMLPFusion,
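
For readers skimming the diff, here is a condensed, self-contained sketch of the subgraph the test builds and why the Multiply port order matters: the LLMMLP fuser matches Multiply(activation, up_proj), so reversing the two ports yields a numerically identical graph that the pattern matcher does not recognize. This is an illustration, not an excerpt from the commit: the helper make_mlp and its constant placeholder weights are hypothetical, the int8 weight-dequantization path is omitted, and only standard OpenVINO 2.0 C++ API calls are assumed.

// Illustrative sketch; make_mlp and the constant weights are hypothetical.
#include <memory>
#include <vector>

#include "openvino/core/model.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/gelu.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/parameter.hpp"

std::shared_ptr<ov::Model> make_mlp(bool swap_inputs, int64_t down_size, int64_t up_size) {
    // [batch, seq, down_size] input activations
    auto src = std::make_shared<ov::op::v0::Parameter>(ov::element::f32,
                                                       ov::PartialShape{-1, -1, down_size});
    auto weight = [](int64_t rows, int64_t cols) {
        // Placeholder weights; the real test uses random int8 weights with per-channel dequantization
        return ov::op::v0::Constant::create(ov::element::f32,
                                            ov::Shape{size_t(rows), size_t(cols)},
                                            std::vector<float>(rows * cols, 0.01f));
    };
    // Gate/up projections: x * W^T with W of shape [up_size, down_size]
    auto gate_proj = std::make_shared<ov::op::v0::MatMul>(src, weight(up_size, down_size), false, true);
    auto up_proj = std::make_shared<ov::op::v0::MatMul>(src, weight(up_size, down_size), false, true);
    auto gate_act = std::make_shared<ov::op::v7::Gelu>(gate_proj);
    // The fusion pattern expects Multiply(activation, up_proj); the swapped order is
    // mathematically identical but is not matched, so no LLMMLP node is created.
    auto gate_up = swap_inputs ? std::make_shared<ov::op::v1::Multiply>(up_proj, gate_act)
                               : std::make_shared<ov::op::v1::Multiply>(gate_act, up_proj);
    auto down_proj = std::make_shared<ov::op::v0::MatMul>(gate_up, weight(down_size, up_size), false, true);
    return std::make_shared<ov::Model>(ov::OutputVector{down_proj}, ov::ParameterVector{src});
}

Given such a graph, the test's check_fusion_result walks compiledModel.get_runtime_model() counting nodes whose layer type is "LLMMLP", and expects exactly 1 for the normal port order and 0 for the swapped one.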
