@@ -20,8 +20,8 @@ namespace npuw_utest{
2020}
2121
2222enum class NetworkKind {
23- llama2,
24- llama3
23+ MHA, // Multi-Head Attention (e.g., llama2) - no broadcast
24+ GQA // Grouped Query Attention (e.g., llama3, phi3, mistral, GPT-OSS) - with broadcast
2525};
2626
2727typedef std::tuple <
@@ -30,6 +30,7 @@ typedef std::tuple <
3030 bool , // withTranspose - without transpose node - matcher shouldn't detect subgraph, easy way to negative case
3131 bool , // withSDPA - should SDPA layer present or be already unrolled or simplified
3232 bool , // use high precision on attention_mask input
33+ bool , // withSink - SDPA with 6th input (sink) for GPT-OSS pattern
3334 NetworkKind
3435> OptimizeVTTestParamsTuple;
3536
@@ -41,11 +42,12 @@ struct OptimizeVTTestParams {
4142 _AT (2 ) withTranspose;
4243 _AT (3 ) withSDPA;
4344 _AT (4 ) withHpAttenMask;
44- _AT (5 ) kind;
45+ _AT (5 ) withSink;
46+ _AT (6 ) kind;
4547 #undef _AT
4648
4749 OptimizeVTTestParams (const OptimizeVTTestParamsTuple& tup) {
48- std::tie (inputShape, withConvert, withTranspose, withSDPA, withHpAttenMask, kind) = tup;
50+ std::tie (inputShape, withConvert, withTranspose, withSDPA, withHpAttenMask, withSink, kind) = tup;
4951 }
5052};
5153
@@ -75,8 +77,9 @@ class TransposeVTTest : public testing::WithParamInterface<OptimizeVTTestParamsT
7577
7678
7779 // validation of High Precision attention mask - implies enabling SDPA layer to be unrolled,
78- // and also specific FP16 activation transformation in partitioning
79- if (test.withSDPA ) {
80+ // and also specific FP16 activation transformation in partitioning
81+ // Note: When withSink=true, standard OpenVINO SDPA decomposition is used which doesn't support HP
82+ if (test.withSDPA && !test.withSink ) {
8083 std::shared_ptr<::intel_npu::OptionsDesc> options_desc;
8184
8285 auto opt_desc = std::make_shared<::intel_npu::OptionsDesc>();
@@ -187,9 +190,10 @@ class TransposeVTTest : public testing::WithParamInterface<OptimizeVTTestParamsT
187190
188191 std::ostringstream result;
189192 result << " npuw_llm_pipeline_" << test.inputShape << " _"
190- << (test.kind == NetworkKind::llama3 ? " LLAMA3 " : " LLAMA2 " )
193+ << (test.kind == NetworkKind::MHA ? " MHA " : " GQA " )
191194 << (test.withConvert ? " _with_convert" : " " )
192195 << (test.withSDPA ? " _SDPA" : " " )
196+ << (test.withSink ? " _Sink" : " " )
193197 << (test.withHpAttenMask ? " _HP" : " " )
194198 << (!test.withTranspose ? " _NEGATIVE" : " " );
195199 return result.str ();
@@ -212,7 +216,7 @@ class TransposeVTTest : public testing::WithParamInterface<OptimizeVTTestParamsT
212216 };
213217
214218 // in case of non broadcast number of input channels significantly smaller
215- auto numChannels = (test.kind == NetworkKind::llama3 ) ? 8 : 32 ;
219+ auto numChannels = (test.kind == NetworkKind::MHA ) ? 32 : 8 ;
216220 auto input_shape = test.inputShape ;
217221 auto input_2 = static_cast <int >(test.inputShape [2 ]);
218222 auto input_3 = static_cast <int >(test.inputShape [3 ]);
@@ -256,7 +260,7 @@ class TransposeVTTest : public testing::WithParamInterface<OptimizeVTTestParamsT
256260
257261 std::shared_ptr<ov::Node> concat_or_reshape = concat;
258262
259- if (test.kind == NetworkKind::llama3 ) {
263+ if (test.kind == NetworkKind::GQA ) {
260264 auto unsqueeze_pattern = create_shape_constant ({2 }, " unsqueese_pattern" );
261265 auto unsqueeze = std::make_shared<ov::op::v0::Unsqueeze>(concat, unsqueeze_pattern);
262266 unsqueeze->set_friendly_name (" unsqueeze" );
@@ -285,25 +289,32 @@ class TransposeVTTest : public testing::WithParamInterface<OptimizeVTTestParamsT
285289 auto k_input = std::make_shared<ov::op::v0::Parameter>(ov::element::f32 , ov::Shape{1 , 32 , input_shape[2 ] + 1 , input_shape[3 ]});
286290 k_input->set_friendly_name (" k_input" );
287291
288-
289292 auto q_input = std::make_shared<ov::op::v0::Parameter>(ov::element::f32 , ov::Shape{1 , 32 , input_shape[2 ] + 1 , input_shape[3 ]});
290293 q_input->set_friendly_name (" q_input" );
291294
292295 auto scale_node = ov::op::v0::Constant::create (ov::element::f32 , ov::Shape{1 }, {1 });
293296
294- // TODO: add sdpa subgraph
295- std::shared_ptr<ov::Node> sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(
296- q_input,
297- k_input,
298- concat_or_reshape,
299- mask_input_1,
300- scale_node,
301- false );
297+ std::shared_ptr<ov::Node> sdpa;
298+ ov::ParameterVector params = {param, param2, mask_input, k_input, q_input};
299+
300+ // SDPA with sink (6 inputs) for GPT-OSS pattern
301+ if (test.withSink ) {
302+ auto sink = std::make_shared<ov::op::v0::Parameter>(ov::element::f32 , ov::Shape{1 , 32 , 1 , 1 });
303+ sink->set_friendly_name (" sink" );
304+ params.push_back (sink);
305+
306+ sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(
307+ q_input, k_input, concat_or_reshape, mask_input_1, scale_node, sink, false );
308+ } else {
309+ sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(
310+ q_input, k_input, concat_or_reshape, mask_input_1, scale_node, false );
311+ }
312+
302313 sdpa->set_friendly_name (" sdpa" );
303314 auto result = std::make_shared<ov::op::v0::Result>(sdpa);
304315
305316 result->set_friendly_name (" res" );
306- return std::make_shared<ov::Model>(ov::ResultVector{result}, ov::ParameterVector{param, param2, mask_input, k_input, q_input} );
317+ return std::make_shared<ov::Model>(ov::ResultVector{result}, params );
307318
308319 } else {
309320 auto param3 = std::make_shared<ov::op::v0::Parameter>(ov::element::f32 , ov::Shape{1 , 32 , 1 , input_shape[2 ] + 1 });
@@ -344,9 +355,11 @@ const std::vector<bool> withSDPA{true, false};
344355
345356const std::vector<bool > withHpAttenMask{true , false };
346357
358+ const std::vector<bool > withSink{true , false };
359+
347360const std::vector<NetworkKind> networkKind = {
348- // llama2 or llama3 type of concat, with convert layer or without
349- NetworkKind::llama2, NetworkKind::llama3
361+ NetworkKind::MHA, // Multi-Head Attention
362+ NetworkKind::GQA // Grouped Query Attention
350363};
351364
352365 INSTANTIATE_TEST_SUITE_P (smoke_Run_MatchAndTransposeVT,
@@ -357,6 +370,7 @@ const std::vector<NetworkKind> networkKind = {
357370 ::testing::ValuesIn(withBroadCast),
358371 ::testing::ValuesIn(withSDPA),
359372 ::testing::ValuesIn(withHpAttenMask),
373+ ::testing::ValuesIn(withSink),
360374 ::testing::ValuesIn(networkKind)),
361375 TransposeVTTest::getTestCaseName);
362376
0 commit comments