1 file changed: +27 -0 lines changed
src/plugins/intel_cpu/src/transformations/cpu_opset/x64/pass (1 file changed: +27 -0 lines changed)
Original file line number | Diff line number | Diff line change
@@ -122,6 +122,33 @@ ov::intel_cpu::MLPFusionPass::MLPFusionPass() {
122122 matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
123123 const auto & pattern_map = m.get_pattern_value_map ();
124124 auto root = m.get_match_root ();
125+ // Check that the first input of Multiply is the gate (activation) branch and the second input is the up branch;
126+ // otherwise, do not fuse.
127+ auto mlp_gated_up_node = pattern_map.at (mlp_gated_up).get_node_shared_ptr ();
128+ auto input0 = mlp_gated_up_node->input_value (0 );
129+ auto input1 = mlp_gated_up_node->input_value (1 );
130+
131+ bool input0_is_gate = false ;
132+ bool input1_is_up = false ;
133+
134+ if (pattern_map.count (mlp_silu_gate) && input0.get_node () == pattern_map.at (mlp_silu_gate).get_node ()) {
135+ input0_is_gate = true ;
136+ }
137+ if (pattern_map.count (mlp_gelu_gate) && input0.get_node () == pattern_map.at (mlp_gelu_gate).get_node ()) {
138+ input0_is_gate = true ;
139+ }
140+
141+ if (pattern_map.count (mlp_up_proj) && input1.get_node () == pattern_map.at (mlp_up_proj).get_node ()) {
142+ input1_is_up = true ;
143+ }
144+ if (pattern_map.count (gate_up_proj_split) &&
145+ input1.get_node () == pattern_map.at (gate_up_proj_split).get_node () && input1.get_index () == 1 ) {
146+ input1_is_up = true ;
147+ }
148+
149+ if (!input0_is_gate || !input1_is_up) {
150+ return false ;
151+ }
125152 auto src = pattern_map.at (input);
126153 if (!src.get_element_type ().is_real ()) {
127154 // FakeQuantize, should skip fusion
You can’t perform that action at this time.
0 commit comments