@@ -37,14 +37,14 @@ using namespace ov::op;
 namespace ov {
 namespace test {
 using InputShapes = std::vector<InputShape>;
-using PagedAttnTestParams = std::tuple<ElementType, InputShapes, bool, bool, ov::AnyMap>;
+using PagedAttnTestParams = std::tuple<ElementType, InputShapes, bool, bool, int32_t, ov::AnyMap>;
 
 class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams>,
                           virtual public ov::test::SubgraphBaseTest,
                           public CPUTestsBase {
 public:
     static std::string getTestCaseName(const testing::TestParamInfo<PagedAttnTestParams>& obj) {
-        const auto& [inType, inputShapes, extendBlockIndices, sinkInput, additional_config] = obj.param;
+        const auto& [inType, inputShapes, extendBlockIndices, sinkInput, slidingWindow, additional_config] = obj.param;
         std::ostringstream result;
         result << "IS=";
         for (const auto& shape : inputShapes) {
@@ -63,6 +63,7 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
         result << "Prc=" << inType << "_";
         result << "ExtendBlockIndices=" << extendBlockIndices << "_";
         result << "SinkInput=" << sinkInput << "_";
+        result << "SlidingWindow=" << slidingWindow << "_";
         result << "config=(";
         for (const auto& configEntry : additional_config) {
             result << configEntry.first << ", " << configEntry.second.as<std::string>() << "_";
@@ -83,7 +84,8 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
     std::shared_ptr<ov::Model> get_model(ov::element::Type data_type,
                                          ov::Dimension::value_type head_size = 64,
                                          ov::Dimension::value_type head_num = 8,
-                                         bool use_sink_input = true) {
+                                         bool use_sink_input = true,
+                                         int32_t sliding_window = 0) {
         // q [batch_in_tokens, head_num * head_size]
         // k [batch_in_tokens, head_num * head_size]
         // v [batch_in_tokens, head_num * head_size]
@@ -106,7 +108,7 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
         auto scale =
             std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{}, std::vector<float>{scale_value});
         auto silding_windows =
-            std::make_shared<ov::op::v0::Constant>(ov::element::i32, Shape{}, std::vector<int32_t>{0});
+            std::make_shared<ov::op::v0::Constant>(ov::element::i32, Shape{}, std::vector<int32_t>{sliding_window});
         auto alibi_slopes = std::make_shared<ov::op::v0::Constant>(ov::element::f32, Shape{0}, std::vector<float>{});
         auto max_context_len =
             std::make_shared<ov::op::v0::Constant>(ov::element::i32, Shape{}, std::vector<float>{128});
@@ -245,6 +247,10 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
         size_t sink_idx = use_sink_input ? 5 : -1;  // sink only exists when use_sink_input=true
 
         std::shared_ptr<ov::op::v13::ScaledDotProductAttention> sdp;
+        // For the sliding window case, set causal=false because we provide an explicit mask that encodes the
+        // sliding window logic. For the normal case, set causal=true so SDPA applies the causal mask internally.
+        bool use_causal = (sliding_window == 0);
+
         if (use_sink_input) {
             // 7-parameter SDPA constructor with sink support
             // Parameters: query, key, value, attn_mask, scale, sink, causal
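The causal flag and the explicit mask encode the same admissibility rule, so only one of the two may be active at a time. A minimal sketch of the two predicates, assuming 0-based global token positions (the helper name is hypothetical, not part of the test):

    #include <cstdint>

    // Query position q may attend to key position k when:
    //   causal only:            k <= q
    //   sliding window (w > 0): k <= q && k > q - w
    inline bool is_attendable(int32_t q, int32_t k, int32_t w) {
        const bool causal = k <= q;
        const bool in_window = (w == 0) || (k > q - w);
        return causal && in_window;
    }
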
@@ -254,7 +260,7 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
                                                                            inputParams[atten_mask_idx],
                                                                            inputParams[scale_idx],
                                                                            inputParams[sink_idx],
-                                                                           true);
+                                                                           use_causal);
         } else {
             // 6-parameter SDPA constructor without sink
             // Parameters: query, key, value, attn_mask, scale, causal
@@ -263,7 +269,7 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
                                                                            v_in,
                                                                            inputParams[atten_mask_idx],
                                                                            inputParams[scale_idx],
-                                                                           true);
+                                                                           use_causal);
         }
         sdp->set_friendly_name("mha");
         auto pastk_assign = std::make_shared<ov::op::v6::Assign>(concatK, var_k);
@@ -294,7 +300,8 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
     }
 
     void SetUp() override {
-        const auto& [inType, inputShapes, extendBlockIndices, sinkInput, additional_config] = this->GetParam();
+        const auto& [inType, inputShapes, extendBlockIndices, sinkInput, slidingWindow, additional_config] =
+            this->GetParam();
         targetDevice = ov::test::utils::DEVICE_CPU;
         rel_threshold = 1e-2f;
         configuration[ov::hint::inference_precision.name()] = ov::element::f32;
@@ -307,7 +314,8 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
         init_input_shapes(inputShapes);
         ov::ParameterVector inputParams;
 
-        function = get_model(inType, 64, 8, sinkInput);
+        this->sliding_window = slidingWindow;
+        function = get_model(inType, 64, 8, sinkInput, slidingWindow);
         targetDevice = ov::test::utils::DEVICE_CPU;
 
         functionRefs = get_ref_model(inType, 64, 8, sinkInput);
@@ -320,10 +328,12 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
         shapes.push_back(targetInputStaticShapes[0]);  // q
         shapes.push_back(targetInputStaticShapes[0]);  // k
         shapes.push_back(targetInputStaticShapes[0]);  // v
-        // atten_mask shape: [1, heads, seq_len, seq_len] - dynamic based on sequence length
+        // atten_mask shape: always rectangular [1, heads, q_len, total_kv_len]
+        // total_kv_len = past_len_count + seq_len to cover all past and current KV tokens
         auto seq_len = targetInputStaticShapes[0][0];
-        shapes.push_back({1, 8, seq_len, seq_len});  // atten_mask
-        shapes.push_back({1});                       // scale
+        size_t total_kv_len = static_cast<size_t>(past_len_count) + seq_len;
+        shapes.push_back({1, 8, seq_len, total_kv_len});  // atten_mask (rectangular)
+        shapes.push_back({1});                            // scale
 
         if (ref_model_uses_sink) {
             shapes.push_back({1, 8, 1, 1});  // sink
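As a worked example of the shape bookkeeping above, a standalone sketch that replays the accumulation (the per-iteration token counts are hypothetical, chosen only for illustration):

    #include <cstdint>
    #include <cstdio>

    int main() {
        int32_t past_len_count = 0;
        const int32_t seq_lens[] = {16, 1, 1};  // one prefill step, then two decode steps
        for (int32_t seq_len : seq_lens) {
            const int32_t total_kv_len = past_len_count + seq_len;
            std::printf("atten_mask shape: {1, 8, %d, %d}\n", seq_len, total_kv_len);
            past_len_count += seq_len;
        }
    }

The mask stays rectangular: its last dimension grows with every processed chunk, while the query dimension tracks only the current chunk.
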
@@ -337,9 +347,13 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
     }
     template <typename IT, typename T>
     static void strided_iota(IT first, size_t n, T value, T stride) {
+        // Generate values in descending order, from high to low.
+        // Descending values simulate attention patterns where earlier tokens have higher scores,
+        // which is useful for testing sliding window attention where recent context is prioritized.
         for (size_t i = 0; i < n; i++) {
-            *first++ = value;
-            value += stride;
+            const float idx = static_cast<float>(n - 1 - i);
+            const T generated = value + stride * static_cast<T>(idx);
+            *first++ = generated;
         }
     }
     virtual void generate(int idx,
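To make the new ordering concrete, a small standalone usage sketch of the generator (simplified to drop the intermediate float cast; the printed values are what the loop above produces for T = float):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    template <typename IT, typename T>
    static void strided_iota(IT first, size_t n, T value, T stride) {
        // Element i receives value + stride * (n - 1 - i), i.e. the sequence runs high to low.
        for (size_t i = 0; i < n; i++)
            *first++ = value + stride * static_cast<T>(n - 1 - i);
    }

    int main() {
        std::vector<float> buf(4);
        strided_iota(buf.begin(), buf.size(), 1.0f, 0.5f);
        for (float v : buf)
            std::cout << v << ' ';  // prints: 2.5 2 1.5 1
        std::cout << '\n';
    }
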
@@ -444,11 +458,39 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
         create_input(params[param_idx++], targetInputStaticShapes[0], idx + 2.0f);  // k
         create_input(params[param_idx++], targetInputStaticShapes[0], idx + 3.0f);  // v
 
-        // atten_mask - create appropriate shape based on sequence length
-        auto seq_len = targetInputStaticShapes[0][0];
-        create_input(params[param_idx++],
-                     {1, targetInputStaticShapes[0][2], seq_len, seq_len},
-                     0.0f);  // atten_mask
+        // atten_mask - always use a rectangular mask [1, head_num, q_len, total_kv_len]
+        auto mask_param = params[param_idx++];
+        const size_t head_num = targetInputStaticShapes[0][2];
+        const size_t q_len = targetInputStaticShapes[0][0];
+        const size_t total_kv_len = static_cast<size_t>(past_len_count) + q_len;
+
+        if (sliding_window > 0) {
+            // Sliding window: rectangular mask with sliding window logic
+            ov::Tensor mask_tensor(ov::element::f32, {1, head_num, q_len, total_kv_len});
+            auto* mask_data = mask_tensor.data<float>();
+            const float neg_inf = -std::numeric_limits<float>::infinity();
+            const int32_t offset = -sliding_window;
+
+            for (size_t h = 0; h < head_num; ++h) {
+                for (size_t q_pos = 0; q_pos < q_len; ++q_pos) {
+                    const int32_t global_q_idx = past_len_count + static_cast<int32_t>(q_pos);
+                    for (size_t kv_idx = 0; kv_idx < total_kv_len; ++kv_idx) {
+                        const int32_t global_k_idx = static_cast<int32_t>(kv_idx);
+                        const bool within_window = global_k_idx > global_q_idx + offset;
+                        const bool causal = global_k_idx <= global_q_idx;
+                        const bool allow = within_window && causal;
+                        const size_t linear_idx = (h * q_len + q_pos) * total_kv_len + kv_idx;
+                        mask_data[linear_idx] = allow ? 0.f : neg_inf;
+                    }
+                }
+            }
+            inputs[mask_param] = mask_tensor;
+            past_len_count += static_cast<int32_t>(q_len);
+        } else {
+            // Normal case: rectangular mask with all zeros (no masking needed, causal handled by SDPA)
+            create_input(mask_param, {1, head_num, q_len, total_kv_len}, 0.0f);
+            past_len_count += static_cast<int32_t>(q_len);
+        }
 
         // scale - single value for scaling
         create_input(params[param_idx++], {1}, 1.0f / std::sqrt(64));  // scale
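For intuition about the mask this loop builds, a minimal standalone sketch that prints one head's allow/deny pattern ('.' = 0.0f, 'x' = -inf); the window size and token counts here are hypothetical:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int32_t sliding_window = 4, past_len_count = 2;
        const int32_t q_len = 3, total_kv_len = past_len_count + q_len;
        for (int32_t q_pos = 0; q_pos < q_len; ++q_pos) {
            const int32_t q = past_len_count + q_pos;  // global query position
            for (int32_t k = 0; k < total_kv_len; ++k) {
                const bool allow = (k > q - sliding_window) && (k <= q);
                std::putchar(allow ? '.' : 'x');
            }
            std::putchar('\n');
        }
    }

With these numbers the rows print "...xx", "....x", and "x....": the last row shows the window sliding past token 0 while causality still blocks future tokens.
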
@@ -458,8 +500,13 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
             create_input(params[param_idx++], {1, targetInputStaticShapes[0][2], 1, 1}, 0.1f);  // sink
         }
 
-        // past_kv
-        create_input(params[param_idx++], targetInputStaticShapes[1], idx + 4.0f);  // past_kv
+        // past_kv - for SDPA with ReadValue/Assign:
+        // - Iteration 0: empty tensor [0,1,8,64]; ReadValue uses it as the initial (empty) value
+        // - Iteration 1+: should stay empty [0,1,8,64]; ReadValue ignores it and uses the Variable state
+        // Note: we always pass an empty tensor since ReadValue/Assign manages the actual KV cache
+        auto past_kv_shape = targetInputStaticShapes[1];
+        past_kv_shape[0] = 0;  // always use an empty past for ReadValue-based SDPA
+        create_input(params[param_idx++], past_kv_shape, 0.0f);  // past_kv (empty)
 
         // beam_idx - shape matching batch dimension
         create_input(params[param_idx++], ov::Shape{targetInputStaticShapes[0][1]},
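The empty past_kv input works because ReadValue/Assign keeps the growing cache inside the inference request rather than in the input tensors. A minimal sketch of that stateful pattern, assuming a simplified [tokens, channels] layout (the shapes and the variable id are hypothetical, not the test's actual graph):

    #include <openvino/openvino.hpp>
    #include <openvino/op/ops.hpp>
    #include <openvino/op/util/variable.hpp>

    std::shared_ptr<ov::Model> make_stateful_concat() {
        auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{-1, 8});
        auto var = std::make_shared<ov::op::util::Variable>(
            ov::op::util::VariableInfo{ov::PartialShape{-1, 8}, ov::element::f32, "kv_state"});
        // ReadValue yields the stored state; on the first infer it falls back to the empty init constant.
        auto init = std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{0, 8}, std::vector<float>{});
        auto read = std::make_shared<ov::op::v6::ReadValue>(init, var);
        // Append the new tokens to the cached ones, then write the result back into the state.
        auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{read, input}, 0);
        auto assign = std::make_shared<ov::op::v6::Assign>(concat, var);
        auto result = std::make_shared<ov::op::v0::Result>(concat);
        return std::make_shared<ov::Model>(ov::ResultVector{result},
                                           ov::SinkVector{assign},
                                           ov::ParameterVector{input});
    }

Each infer call then appends its input to the stored state, which is why the state has to be cleared between independent runs.
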
@@ -483,6 +530,7 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
     ov::Tensor key_cache;
     ov::Tensor value_cache;
     int32_t past_len_count = 0;
+    int32_t sliding_window = 0;
 };
 
 class PagedAttnVSSDPATest : public PagedAttnTestBase {
@@ -530,7 +578,7 @@ class PagedAttnVSSDPATest : public PagedAttnTestBase {
         std::vector<ov::Tensor> outputs;
         int idx = 0;
         for (auto&& shapes : targetStaticShapes) {
-            generate(idx++, false, shapes, false, sinkInput);  // Use the same sink input setting as the test model
+            generate(idx++, false, shapes, false, sinkInput);
             for (const auto& input : inputs) {
                 inferRequest.set_tensor(input.first, input.second);
             }
@@ -540,25 +588,32 @@ class PagedAttnVSSDPATest : public PagedAttnTestBase {
             outputTensor.copy_to(copy);
             outputs.push_back(copy);
         }
+        reset();
         return outputs;
     }
 };
 
 TEST_P(PagedAttnVSSDPATest, CompareWithRefs) {
     SKIP_IF_CURRENT_TEST_IS_DISABLED();
-    const auto& [inType, inputShapes, extendBlockIndices, sinkInput, additional_config] = this->GetParam();
+    const auto& [inType, inputShapes, extendBlockIndices, sinkInput, slidingWindow, additional_config] =
+        this->GetParam();
     const bool isSageAttn =
         intel_cpu::contains_key_value(additional_config, {ov::intel_cpu::enable_sage_attn.name(), true});
     if (inType == ElementType::bf16 && !ov::with_cpu_x86_bfloat16())
         GTEST_SKIP();
     if (isSageAttn && !(ov::with_cpu_x86_avx512_core_amx_int8() || CPUTestUtils::with_cpu_x86_avx2_vnni_2()))
         GTEST_SKIP();
+
+    past_len_count = 0;
+
     // compare the logits from paged attn and sdpa
     auto actualOutputs = run_test(function, extendBlockIndices, sinkInput);
     // reference model doesn't support sage attention
     if (isSageAttn) {
         configuration[ov::intel_cpu::enable_sage_attn.name()] = false;
     }
+    // Reset past_len_count before running the reference test to ensure consistent mask generation
+    past_len_count = 0;
     auto expectedOutputs = run_ref_test(functionRefs, sinkInput);
     for (size_t i = 0; i < actualOutputs.size(); i++) {
         ov::test::utils::compare(expectedOutputs[i], actualOutputs[i], abs_threshold, rel_threshold);
@@ -582,8 +637,21 @@ INSTANTIATE_TEST_SUITE_P(smoke_PagedAttnVSSDPATest,
                                             ::testing::ValuesIn(inputShapeAndReorders),
                                             ::testing::Values(true, false),
                                             ::testing::Values(true, false),
+                                            ::testing::Values(0),  // sliding_window = 0
                                             ::testing::ValuesIn(additional_configs)),
                          PagedAttnTestBase::getTestCaseName);
+
+// Sliding window test with the same shapes as the normal test
+INSTANTIATE_TEST_SUITE_P(smoke_PagedAttnVSSDPATest_WithSlidingWindow,
+                         PagedAttnVSSDPATest,
+                         ::testing::Combine(::testing::Values(ElementType::f32),
+                                            ::testing::ValuesIn(inputShapeAndReorders),
+                                            ::testing::Values(false),  // extendBlockIndices
+                                            ::testing::Values(true),   // sinkInput
+                                            ::testing::Values(8),      // sliding_window = 8
+                                            ::testing::Values(ov::AnyMap{
+                                                {ov::intel_cpu::enable_sage_attn.name(), false}})),
+                         PagedAttnTestBase::getTestCaseName);
 } // namespace
 
 class PagedAttnVSMatmulTest : public PagedAttnTestBase {
@@ -786,7 +854,8 @@ class PagedAttnVSMatmulTest : public PagedAttnTestBase {
 
 TEST_P(PagedAttnVSMatmulTest, CompareWithRefs) {
     SKIP_IF_CURRENT_TEST_IS_DISABLED();
-    const auto& [inType, inputShapes, extendBlockIndices, sinkInput, additional_config] = this->GetParam();
+    const auto& [inType, inputShapes, extendBlockIndices, sinkInput, slidingWindow, additional_config] =
+        this->GetParam();
     const bool isSageAttn =
         intel_cpu::contains_key_value(additional_config, {ov::intel_cpu::enable_sage_attn.name(), true});
     if (inType == ElementType::bf16 && !ov::with_cpu_x86_bfloat16())
@@ -821,6 +890,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_PagedAttnVSMatmulTest,
                                             ::testing::ValuesIn(inputShapes),
                                             ::testing::Values(true, false),
                                             ::testing::Values(false),
+                                            ::testing::Values(0),  // sliding_window = 0
                                             ::testing::ValuesIn(additional_configs)),
                          PagedAttnTestBase::getTestCaseName);
 } // namespace