#include <vector>

#include "common_test_utils/ov_tensor_utils.hpp"
- #include "openvino/runtime/exec_model_info.hpp"
- #include "shared_test_classes/base/ov_subgraph.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/gelu.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/swish.hpp"
+ #include "openvino/op/variadic_split.hpp"
+ #include "openvino/runtime/exec_model_info.hpp"
+ #include "shared_test_classes/base/ov_subgraph.hpp"
+ #include "transformations/rt_info/decompression.hpp"

namespace ov {
namespace test {
@@ -23,6 +25,7 @@ struct LLMMLPFusionParams {
size_t up_size;
std::string act_type;
bool use_dynamic_quant;
+ bool use_swapped_outputs;  // true = create pattern with swapped VariadicSplit outputs (should still fuse)
};

class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>, public ov::test::SubgraphBaseTest {
@@ -39,6 +42,7 @@ class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>,
result << "up_size=" << obj.param.up_size << "_";
result << "act_type=" << obj.param.act_type << "_";
result << "use_dynamic_quant=" << obj.param.use_dynamic_quant << "_";
+ result << "use_swapped_outputs=" << obj.param.use_swapped_outputs << "_";
result << obj.index;
return result.str();
}
@@ -70,7 +74,8 @@ class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>,
in_data.start_from = 0;
in_data.range = 1;
in_data.resolution = 128;
- auto tensor_scale_per_oc = ov::test::utils::create_and_fill_tensor(ov::element::f32, ov::Shape{OC, 1}, in_data);
+ auto tensor_scale_per_oc =
+ ov::test::utils::create_and_fill_tensor(ov::element::f32, ov::Shape{OC, 1}, in_data);
auto scale_per_oc = std::make_shared<ov::op::v0::Constant>(tensor_scale_per_oc);

auto weight_deq = std::make_shared<ov::op::v1::Multiply>(weight_const_f32, scale_per_oc);
@@ -85,38 +90,89 @@ class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>,
return std::make_shared<ov::op::v0::Constant>(tensor);
};
if (param.use_dynamic_quant)
- configuration.insert({ov::hint::dynamic_quantization_group_size.name(), std::numeric_limits<uint64_t>::max()});
-
- auto gate_weight = create_const(param.up_size, param.down_size, 100);
- auto up_weight = create_const(param.up_size, param.down_size, 100);
- // down_proj has special cache blocking along K dimension requires lower weight resolution
- auto down_weight = create_const(param.down_size, param.up_size, 16);
-
- auto gate_proj = std::make_shared<ov::op::v0::MatMul>(src, gate_weight, false, true);
- auto up_proj = std::make_shared<ov::op::v0::MatMul>(src, up_weight, false, true);
+ configuration.insert(
+ {ov::hint::dynamic_quantization_group_size.name(), std::numeric_limits<uint64_t>::max()});
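+ // (uint64_t max means no group-size restriction for dynamic quantization)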

std::shared_ptr<Node> gate_act;
- if (param.act_type == "Swish")
- gate_act = std::make_shared<ov::op::v4::Swish>(gate_proj);
- if (param.act_type == "Gelu")
- gate_act = std::make_shared<ov::op::v7::Gelu>(gate_proj);
+ ov::Output<ov::Node> up_output;

- auto gate_up = std::make_shared<ov::op::v1::Multiply>(gate_act, up_proj);
+ if (param.use_swapped_outputs) {
+ // Create pattern with swapped VariadicSplit outputs to test gate_up_swapped support
+ ov::test::utils::InputGenerateData in_data;
+ in_data.start_from = -0.5;
+ in_data.range = 1.0;
+ in_data.resolution = 16;
+
+ // Combined gate_up weight in FP16 format
+ auto tensor_f16 = ov::test::utils::create_and_fill_tensor(ov::element::f16,
+ ov::Shape{param.up_size * 2, param.down_size},
+ in_data);
+ auto gate_up_weight_f16 = std::make_shared<ov::op::v0::Constant>(tensor_f16);
+ auto gate_up_weight_f32 = std::make_shared<ov::op::v0::Convert>(gate_up_weight_f16, ov::element::f32);
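+ // Mark the f16->f32 Convert as weight decompression so the plugin treats the f16 constant as compressed weights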
+ mark_as_decompression(gate_up_weight_f32);
+
+ auto gate_up_proj = std::make_shared<ov::op::v0::MatMul>(src, gate_up_weight_f32, false, true);
+
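+ // Split the combined gate/up projection along the last axis into two equal halves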
+ auto split_lengths = std::make_shared<ov::op::v0::Constant>(
+ ov::element::i32,
+ ov::Shape{2},
+ std::vector<int32_t>{static_cast<int32_t>(param.up_size), static_cast<int32_t>(param.up_size)});
+ auto axis_const = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, -1);
+ auto gate_up_split = std::make_shared<ov::op::v1::VariadicSplit>(gate_up_proj, axis_const, split_lengths);
+
+ // Swap outputs to test gate_up_swapped support
+ auto gate_part = gate_up_split->output(1);  // activation on output[1]
+ if (param.act_type == "Swish")
+ gate_act = std::make_shared<ov::op::v4::Swish>(gate_part);
+ if (param.act_type == "Gelu")
+ gate_act = std::make_shared<ov::op::v7::Gelu>(gate_part);
+
+ auto up_part = gate_up_split->output(0);  // up branch from output[0] (swapped case)
+ up_output = up_part;
+ } else {
+ // Standard separate weights pattern
+ auto gate_weight = create_const(param.up_size, param.down_size, 100);
+ auto up_weight = create_const(param.up_size, param.down_size, 100);
+
+ auto gate_proj = std::make_shared<ov::op::v0::MatMul>(src, gate_weight, false, true);
+ auto up_proj = std::make_shared<ov::op::v0::MatMul>(src, up_weight, false, true);
+
+ if (param.act_type == "Swish")
+ gate_act = std::make_shared<ov::op::v4::Swish>(gate_proj);
+ if (param.act_type == "Gelu")
+ gate_act = std::make_shared<ov::op::v7::Gelu>(gate_proj);
+
+ up_output = up_proj;
+ }
+
+ // Create compressed down projection weight
+ ov::test::utils::InputGenerateData down_data;
+ down_data.start_from = -0.5;
+ down_data.range = 1;
+ down_data.resolution = 16;
+ auto tensor_f16_down = ov::test::utils::create_and_fill_tensor(ov::element::f16,
+ ov::Shape{param.down_size, param.up_size},
+ down_data);
+ auto down_weight_f16 = std::make_shared<ov::op::v0::Constant>(tensor_f16_down);
+ auto down_weight = std::make_shared<ov::op::v0::Convert>(down_weight_f16, ov::element::f32);
+
+ auto gate_up = std::make_shared<ov::op::v1::Multiply>(gate_act, up_output);
auto output = std::make_shared<ov::op::v0::MatMul>(gate_up, down_weight, false, true);

function = std::make_shared<ov::Model>(ov::OutputVector{output}, ov::ParameterVector{src});
}

void check_results() {
auto exec_model = compiledModel.get_runtime_model();
-
int fused_node_found = 0;
for (const auto& n : exec_model->get_ordered_ops()) {
auto layer_type = n->get_rt_info().at(ov::exec_model_info::LAYER_TYPE).as<std::string>();
if (layer_type == "LLMMLP")
fused_node_found++;
}
- ASSERT_EQ(fused_node_found, 1);
+
+ // Both normal and swapped cases should fuse successfully
+ ASSERT_EQ(fused_node_found, 1) << "Fusion should occur with valid MLP patterns (both normal and swapped cases)";
}
};

@@ -129,13 +185,18 @@ TEST_P(LLMMLPFusionTest, CompareWithRefs) {

namespace {

- static ov::test::InputShape ishape{ov::PartialShape{-1, -1, 4096 / 4}, {ov::Shape{1, 8, 4096 / 4}, ov::Shape{5, 37, 4096 / 4}}};
+ static ov::test::InputShape ishape{ov::PartialShape{-1, -1, 4096 / 4},
+ {ov::Shape{1, 8, 4096 / 4}, ov::Shape{5, 37, 4096 / 4}}};

const std::vector<LLMMLPFusionParams> mlp_params = {
- {ishape, 4096 / 4, 11008 / 4, "Gelu", false},
- {ishape, 4096 / 4, 11008 / 4, "Gelu", true},
- {ishape, 4096 / 4, 11008 / 4, "Swish", false},
- {ishape, 4096 / 4, 11008 / 4, "Swish", true},
+ // Standard separate weights cases (should all fuse successfully)
+ {ishape, 4096 / 4, 11008 / 4, "Gelu", false, false},
+ {ishape, 4096 / 4, 11008 / 4, "Gelu", true, false},
+ {ishape, 4096 / 4, 11008 / 4, "Swish", false, false},
+ {ishape, 4096 / 4, 11008 / 4, "Swish", true, false},
+
+ // Test case with swapped VariadicSplit outputs (should fuse with gate_up_swapped=true)
+ {ishape, 4096 / 4, 11008 / 4, "Gelu", false, true},
};

INSTANTIATE_TEST_SUITE_P(smoke_LLMMLPFusion,