Skip to content

Commit b7ff7b3

Browse files
committed
Fix Alibaba-NLP accuracy issue: avoid fusing MLP with unmatched VariadicSplit output ports
1 parent c733cd0 commit b7ff7b3

File tree

1 file changed

+27
-0
lines changed
  • src/plugins/intel_cpu/src/transformations/cpu_opset/x64/pass

1 file changed

+27
-0
lines changed

src/plugins/intel_cpu/src/transformations/cpu_opset/x64/pass/mlp_fusion.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,33 @@ ov::intel_cpu::MLPFusionPass::MLPFusionPass() {
122122
matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
123123
const auto& pattern_map = m.get_pattern_value_map();
124124
auto root = m.get_match_root();
125+
// Check that input 0 of the Multiply is the gate (activation) branch and input 1 is the up branch (either the up projection or output port 1 of the gate/up split);
126+
// otherwise, do not fuse.
127+
auto mlp_gated_up_node = pattern_map.at(mlp_gated_up).get_node_shared_ptr();
128+
auto input0 = mlp_gated_up_node->input_value(0);
129+
auto input1 = mlp_gated_up_node->input_value(1);
130+
131+
bool input0_is_gate = false;
132+
bool input1_is_up = false;
133+
134+
if (pattern_map.count(mlp_silu_gate) && input0.get_node() == pattern_map.at(mlp_silu_gate).get_node()) {
135+
input0_is_gate = true;
136+
}
137+
if (pattern_map.count(mlp_gelu_gate) && input0.get_node() == pattern_map.at(mlp_gelu_gate).get_node()) {
138+
input0_is_gate = true;
139+
}
140+
141+
if (pattern_map.count(mlp_up_proj) && input1.get_node() == pattern_map.at(mlp_up_proj).get_node()) {
142+
input1_is_up = true;
143+
}
144+
if (pattern_map.count(gate_up_proj_split) &&
145+
input1.get_node() == pattern_map.at(gate_up_proj_split).get_node() && input1.get_index() == 1) {
146+
input1_is_up = true;
147+
}
148+
149+
if (!input0_is_gate || !input1_is_up) {
150+
return false;
151+
}
125152
auto src = pattern_map.at(input);
126153
if (!src.get_element_type().is_real()) {
127154
// Source element type is not real (e.g. the FakeQuantize case), so fusion should be skipped

0 commit comments

Comments
 (0)