@@ -123,8 +123,8 @@ ov::intel_cpu::MLPFusionPass::MLPFusionPass() {
 const auto& pattern_map = m.get_pattern_value_map();
 auto root = m.get_match_root();

-// Check VariadicSplit output connections in combined mode
-bool gate_up_swapped = false;
+// Determine gate_up_type based on pattern matching
+LLMMLPNode::GATE_UP_TYPE gate_up_type = LLMMLPNode::GATE_UP_TYPE::SEPARATE;
 if (pattern_map.count(gate_up_proj_split)) {
     auto mlp_gated_up_node = pattern_map.at(mlp_gated_up).get_node_shared_ptr();
     auto input0 = mlp_gated_up_node->input_value(0);
@@ -134,9 +134,10 @@ ov::intel_cpu::MLPFusionPass::MLPFusionPass() {
     // Since pattern matching succeeded, we know one of the outputs connects to Multiply
     if ((input0.get_node() == pattern_map.at(gate_up_proj_split).get_node() && input0.get_index() == 0) ||
         (input1.get_node() == pattern_map.at(gate_up_proj_split).get_node() && input1.get_index() == 0)) {
-        gate_up_swapped = true;
+        gate_up_type = LLMMLPNode::GATE_UP_TYPE::COMBINED_UP_GATE;  // swapped case
+    } else {
+        gate_up_type = LLMMLPNode::GATE_UP_TYPE::COMBINED_GATE_UP;  // normal combined case
     }
-    // Otherwise, it's the normal case where output[1] connects to Multiply
 }

 auto src = pattern_map.at(input);
@@ -151,17 +152,20 @@ ov::intel_cpu::MLPFusionPass::MLPFusionPass() {
 // down projection is harder to quantize w/o causing accuracy problem, so it may be un-quantized instead
 bool is_gate_up_quantized_int8 = false;
 bool is_down_proj_int8 = false;
-bool is_gate_up_combined = false;
 if (pattern_map.count(gate_up_proj_weight_const_i8) > 0 && pattern_map.count(down_proj_weight_compressed) > 0) {
     // gate-up combined & quantized
     is_gate_up_quantized_int8 = true;
-    is_gate_up_combined = true;
+    gate_up_type = (gate_up_type == LLMMLPNode::GATE_UP_TYPE::SEPARATE)
+                       ? LLMMLPNode::GATE_UP_TYPE::COMBINED_GATE_UP
+                       : gate_up_type;
     gate_proj_w = pattern_map.at(gate_up_proj_weight_const_i8);
     up_proj_w = pattern_map.at(gate_up_proj_weight_const_i8);
     down_proj_w = pattern_map.at(down_proj_weight_compressed);
 } else if (pattern_map.count(gate_up_proj_weight) > 0 && pattern_map.count(down_proj_weight_compressed) > 0) {
     // gate-up combined
-    is_gate_up_combined = true;
+    gate_up_type = (gate_up_type == LLMMLPNode::GATE_UP_TYPE::SEPARATE)
+                       ? LLMMLPNode::GATE_UP_TYPE::COMBINED_GATE_UP
+                       : gate_up_type;
     gate_proj_w = pattern_map.at(gate_up_proj_weight);
     up_proj_w = pattern_map.at(gate_up_proj_weight);
     down_proj_w = pattern_map.at(down_proj_weight_compressed);
@@ -224,7 +228,7 @@ ov::intel_cpu::MLPFusionPass::MLPFusionPass() {
     return false;
 }

-auto up_size = is_gate_up_combined ? (up_shape[0] / 2) : (up_shape[0]);
+auto up_size = (gate_up_type != LLMMLPNode::GATE_UP_TYPE::SEPARATE) ? (up_shape[0] / 2) : (up_shape[0]);
 auto down_size = up_shape[1];
 if (down_shape[0] != down_size) {
     return false;
@@ -240,8 +244,7 @@ ov::intel_cpu::MLPFusionPass::MLPFusionPass() {
 cfg.down_quantized = is_down_proj_int8;
 cfg.hidden_size = down_size;
 cfg.up_size = up_size;
-cfg.gate_up_combined = is_gate_up_combined;
-cfg.gate_up_swapped = gate_up_swapped;
+cfg.gate_up_type = gate_up_type;

 if (pattern_map.count(mlp_silu_gate) > 0) {
     cfg.act = LLMMLPNode::ACT_FN::SILU;
@@ -266,7 +269,7 @@ ov::intel_cpu::MLPFusionPass::MLPFusionPass() {
 new_args.push_back(up_proj_w);
 new_args.push_back(down_proj_w);
 if (is_gate_up_quantized_int8) {
-    if (is_gate_up_combined) {
+    if (gate_up_type != LLMMLPNode::GATE_UP_TYPE::SEPARATE) {
         new_args.push_back(pattern_map.at(gate_up_proj_weight_scales_per_OC));
         new_args.push_back(pattern_map.at(gate_up_proj_weight_scales_per_OC));
     } else {
@@ -284,7 +287,7 @@ ov::intel_cpu::MLPFusionPass::MLPFusionPass() {
 ov::copy_runtime_info(
     {pattern_map.at(gate_act).get_node_shared_ptr(), pattern_map.at(down_proj).get_node_shared_ptr()},
     new_node);
-if (is_gate_up_combined) {
+if (gate_up_type != LLMMLPNode::GATE_UP_TYPE::SEPARATE) {
     ov::copy_runtime_info({pattern_map.at(gate_up_proj).get_node_shared_ptr()}, new_node);
 } else {
     ov::copy_runtime_info({pattern_map.at(mlp_gate_proj).get_node_shared_ptr(),
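
Note for reviewers: the change above folds the previous gate_up_combined / gate_up_swapped booleans into a single three-state enum on the node config. The actual declaration lives in the LLMMLPNode header, which is not part of this diff; the sketch below is only an assumed minimal shape consistent with the names used here (GATE_UP_TYPE, SEPARATE, COMBINED_GATE_UP, COMBINED_UP_GATE, cfg.gate_up_type), not the real header.

// Assumed sketch only -- the real definition in the LLMMLPNode header may differ.
class LLMMLPNode {
public:
    enum class GATE_UP_TYPE {
        SEPARATE,          // gate and up projections come from two separate weight constants
        COMBINED_GATE_UP,  // one fused gate/up weight in [gate, up] order (normal combined case)
        COMBINED_UP_GATE,  // one fused weight in [up, gate] order (the "swapped case" above)
    };

    struct Config {
        // ...existing fields (act, hidden_size, up_size, quantization flags)...
        GATE_UP_TYPE gate_up_type = GATE_UP_TYPE::SEPARATE;  // replaces gate_up_combined + gate_up_swapped
    };
};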