@@ -138,20 +138,23 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
138138 std::make_shared<ov::op::v0::Constant>(ov::element::i32 , Shape{}, std::vector<int32_t >{64 });
139139 auto xattention_stride =
140140 std::make_shared<ov::op::v0::Constant>(ov::element::i32 , Shape{}, std::vector<int32_t >{8 });
141- // Create sink_input parameter for testing - shape [1, num_heads, 1, 1] as per PagedAttentionExecutor::ID_SINKS
142- // PagedAttentionExtension always expects 21 inputs, so we must always include sinks parameter
143- auto sinks = make_param (PartialShape{1 , head_num, 1 , 1 }, data_type, " sinks" );
144-
145- ParameterVector params = {q,
146- k,
147- v,
148- key_cache,
149- value_cache,
150- past_lens,
151- subsequence_begins,
152- block_indices,
153- block_indices_begins,
154- sinks};
141+ // Create sink input as Constant (not Parameter) for testing
142+ // PagedAttentionExtension requires sink input to be Constant
143+ // Use shape [1, head_num, 1, 1] when use_sink_input=true, or empty shape [0] when false
144+ std::shared_ptr<ov::op::v0::Constant> sinks;
145+ if (use_sink_input) {
146+ // Create real sink tokens for testing sink functionality
147+ std::vector<float > sink_data (static_cast <size_t >(head_num), 0 .1f );
148+ sinks = std::make_shared<ov::op::v0::Constant>(data_type,
149+ Shape{1 , static_cast <size_t >(head_num), 1 , 1 },
150+ sink_data);
151+ } else {
152+ // Create empty sink (matching SDPA->PA transformation behavior when no sink)
153+ sinks = std::make_shared<ov::op::v0::Constant>(data_type, Shape{0 }, std::vector<float >{});
154+ }
155+
156+ ParameterVector params =
157+ {q, k, v, key_cache, value_cache, past_lens, subsequence_begins, block_indices, block_indices_begins};
155158 OutputVector paged_attn_inputs = {q,
156159 k,
157160 v,
@@ -461,10 +464,7 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
461464 inputs.insert ({function->get_parameters ()[7 ], block_indices});
462465 inputs.insert ({function->get_parameters ()[8 ], block_indices_begins});
463466
464- // Create sink_input data - shape [1, num_heads, 1, 1] as per PagedAttentionExtensor specification
465- // Always create the sink input data since PagedAttentionExtension expects 21 inputs
466- // The value will be ignored when use_sink_input=false
467- create_input (function->get_parameters ()[9 ], {1 , qkv_shape[2 ], 1 , 1 }, use_sink_input ? 0 .1f : 0 .0f );
467+ // Note: sink is a Constant in the model, not a Parameter, so no need to provide input for it
468468
469469 past_len_count += static_cast <int32_t >(qkv_shape[0 ]);
470470
@@ -655,21 +655,22 @@ INSTANTIATE_TEST_SUITE_P(smoke_PagedAttnVSSDPATest,
655655 ::testing::Combine (::testing::Values(ElementType::f32 , ElementType::bf16 ),
656656 ::testing::ValuesIn(inputShapeAndReorders),
657657 ::testing::Values(true , false ),
658+ // TODO: Xattn should not be directly compared with SDPA/decomposed Matmul,
659+ // which do not contain the sparse logic
658660 ::testing::Values(true , false ),
659- ::testing::Values(true , false ),
660- ::testing::Values(0 ), // sliding_window = 0
661+ ::testing::Values(false ),
662+ ::testing::Values(0 ),
661663 ::testing::ValuesIn(additional_configs)),
662664 PagedAttnTestBase::getTestCaseName);
663665
664- // Sliding window test with same shapes as normal test
665- INSTANTIATE_TEST_SUITE_P (smoke_PagedAttnVSSDPATest_WithSlidingWindow,
666+ INSTANTIATE_TEST_SUITE_P (smoke_PagedAttnVSSDPATest_WithSlidingWindowAndSinks,
666667 PagedAttnVSSDPATest,
667668 ::testing::Combine (::testing::Values(ElementType::f32 ),
668669 ::testing::ValuesIn(inputShapeAndReorders),
669- ::testing::Values(false ), // extendBlockIndices
670- ::testing::Values(false ), // enableXattn
671- ::testing::Values(true ), // sinkInput
672- ::testing::Values(8 ), // sliding_window = 8
670+ ::testing::Values(false ), // extendBlockIndices
671+ ::testing::Values(false ), // enableXattn
672+ ::testing::Values(true , false ), // sinkInput
673+ ::testing::Values(0 , 8 ), // sliding_window values: 0 and 8
673674 ::testing::Values(ov::AnyMap{
674675 {ov::intel_cpu::enable_sage_attn.name (), false }})),
675676 PagedAttnTestBase::getTestCaseName);
0 commit comments