
Commit e25dc47

Add PA sliding_window testcase
1 parent 00a7592 commit e25dc47

File tree

1 file changed (+93, -23):
  • src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/paged_attn.cpp

src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/paged_attn.cpp

Lines changed: 93 additions & 23 deletions
@@ -37,14 +37,14 @@ using namespace ov::op;
 namespace ov {
 namespace test {
 using InputShapes = std::vector<InputShape>;
-using PagedAttnTestParams = std::tuple<ElementType, InputShapes, bool, bool, ov::AnyMap>;
+using PagedAttnTestParams = std::tuple<ElementType, InputShapes, bool, bool, int32_t, ov::AnyMap>;

 class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams>,
                           virtual public ov::test::SubgraphBaseTest,
                           public CPUTestsBase {
 public:
     static std::string getTestCaseName(const testing::TestParamInfo<PagedAttnTestParams>& obj) {
-        const auto& [inType, inputShapes, extendBlockIndices, sinkInput, additional_config] = obj.param;
+        const auto& [inType, inputShapes, extendBlockIndices, sinkInput, slidingWindow, additional_config] = obj.param;
         std::ostringstream result;
         result << "IS=";
         for (const auto& shape : inputShapes) {
@@ -63,6 +63,7 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
         result << "Prc=" << inType << "_";
         result << "ExtendBlockIndices=" << extendBlockIndices << "_";
         result << "SinkInput=" << sinkInput << "_";
+        result << "SlidingWindow=" << slidingWindow << "_";
         result << "config=(";
         for (const auto& configEntry : additional_config) {
             result << configEntry.first << ", " << configEntry.second.as<std::string>() << "_";
@@ -83,7 +84,8 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
     std::shared_ptr<ov::Model> get_model(ov::element::Type data_type,
                                          ov::Dimension::value_type head_size = 64,
                                          ov::Dimension::value_type head_num = 8,
-                                         bool use_sink_input = true) {
+                                         bool use_sink_input = true,
+                                         int32_t sliding_window = 0) {
         // q [batch_in_tokens, head_num * head_size]
         // k [batch_in_tokens, head_num * head_size]
         // v [batch_in_tokens, head_num * head_size]
@@ -106,7 +108,7 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
         auto scale =
             std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{}, std::vector<float>{scale_value});
         auto silding_windows =
-            std::make_shared<ov::op::v0::Constant>(ov::element::i32, Shape{}, std::vector<int32_t>{0});
+            std::make_shared<ov::op::v0::Constant>(ov::element::i32, Shape{}, std::vector<int32_t>{sliding_window});
         auto alibi_slopes = std::make_shared<ov::op::v0::Constant>(ov::element::f32, Shape{0}, std::vector<float>{});
         auto max_context_len =
             std::make_shared<ov::op::v0::Constant>(ov::element::i32, Shape{}, std::vector<float>{128});
@@ -245,6 +247,10 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
         size_t sink_idx = use_sink_input ? 5 : -1;  // sink only exists when use_sink_input=true

         std::shared_ptr<ov::op::v13::ScaledDotProductAttention> sdp;
+        // For sliding window case, set causal=false because we provide explicit mask with sliding window logic
+        // For normal case, set causal=true to let SDPA apply causal mask internally
+        bool use_causal = (sliding_window == 0);
+
         if (use_sink_input) {
             // 7-parameter SDPA constructor with sink support
             // Parameters: query, key, value, attn_mask, scale, sink, causal
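Why causal must be false once the mask is explicit: the additive mask already drives hidden positions to -inf before softmax, so SDPA's internal causal masking would be redundant for causal entries and would know nothing about the window. A toy numeric sketch of that mechanism (example scores, not the test's API):

```cpp
#include <cmath>
#include <cstdio>
#include <limits>

int main() {
    // One query over three keys; suppose the sliding window hides the first key.
    const float neg_inf = -std::numeric_limits<float>::infinity();
    const float scores[3] = {1.0f, 2.0f, 3.0f};
    const float mask[3]   = {neg_inf, 0.0f, 0.0f};  // additive mask, as in the test

    float denom = 0.0f;
    for (int i = 0; i < 3; ++i)
        denom += std::exp(scores[i] + mask[i]);  // exp(-inf) == 0: masked key drops out
    for (int i = 0; i < 3; ++i)
        std::printf("%.3f ", std::exp(scores[i] + mask[i]) / denom);  // 0.000 0.269 0.731
    return 0;
}
```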
@@ -254,7 +260,7 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
                                                                        inputParams[atten_mask_idx],
                                                                        inputParams[scale_idx],
                                                                        inputParams[sink_idx],
-                                                                       true);
+                                                                       use_causal);
         } else {
             // 6-parameter SDPA constructor without sink
             // Parameters: query, key, value, attn_mask, scale, causal
@@ -263,7 +269,7 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
                                                                        v_in,
                                                                        inputParams[atten_mask_idx],
                                                                        inputParams[scale_idx],
-                                                                       true);
+                                                                       use_causal);
         }
         sdp->set_friendly_name("mha");
         auto pastk_assign = std::make_shared<ov::op::v6::Assign>(concatK, var_k);
@@ -294,7 +300,8 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
     }

     void SetUp() override {
-        const auto& [inType, inputShapes, extendBlockIndices, sinkInput, additional_config] = this->GetParam();
+        const auto& [inType, inputShapes, extendBlockIndices, sinkInput, slidingWindow, additional_config] =
+            this->GetParam();
         targetDevice = ov::test::utils::DEVICE_CPU;
         rel_threshold = 1e-2f;
         configuration[ov::hint::inference_precision.name()] = ov::element::f32;
@@ -307,7 +314,8 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
         init_input_shapes(inputShapes);
         ov::ParameterVector inputParams;

-        function = get_model(inType, 64, 8, sinkInput);
+        this->sliding_window = slidingWindow;
+        function = get_model(inType, 64, 8, sinkInput, slidingWindow);
         targetDevice = ov::test::utils::DEVICE_CPU;

         functionRefs = get_ref_model(inType, 64, 8, sinkInput);
@@ -320,10 +328,12 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
             shapes.push_back(targetInputStaticShapes[0]);  // q
             shapes.push_back(targetInputStaticShapes[0]);  // k
             shapes.push_back(targetInputStaticShapes[0]);  // v
-            // atten_mask shape: [1, heads, seq_len, seq_len] - dynamic based on sequence length
+            // atten_mask shape: always rectangular [1, heads, q_len, total_kv_len]
+            // total_kv_len = past_len_count + seq_len to cover all past and current KV tokens
             auto seq_len = targetInputStaticShapes[0][0];
-            shapes.push_back({1, 8, seq_len, seq_len});  // atten_mask
-            shapes.push_back({1});                       // scale
+            size_t total_kv_len = static_cast<size_t>(past_len_count) + seq_len;
+            shapes.push_back({1, 8, seq_len, total_kv_len});  // atten_mask (rectangular)
+            shapes.push_back({1});                            // scale

             if (ref_model_uses_sink) {
                 shapes.push_back({1, 8, 1, 1});  // sink
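The reference model's attention mask is now rectangular: the KV axis grows by the tokens already accumulated in past_len_count. A minimal sketch of that bookkeeping, with assumed token counts (a 16-token prefill, then two 1-token decode steps):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
    int32_t past_len_count = 0;
    const size_t q_lens[] = {16, 1, 1};  // e.g. a prefill step, then two decode steps
    for (size_t q_len : q_lens) {
        // Same arithmetic as the test: the KV axis covers past plus current tokens.
        const size_t total_kv_len = static_cast<size_t>(past_len_count) + q_len;
        std::printf("mask: %zu x %zu\n", q_len, total_kv_len);  // 16x16, then 1x17, then 1x18
        past_len_count += static_cast<int32_t>(q_len);
    }
    return 0;
}
```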
@@ -337,9 +347,13 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
     }
     template <typename IT, typename T>
     static void strided_iota(IT first, size_t n, T value, T stride) {
+        // Generate values in descending order (high to low) to simulate attention patterns where
+        // earlier tokens have higher scores; this is useful for testing sliding window attention,
+        // where recent context is prioritized.
         for (size_t i = 0; i < n; i++) {
-            *first++ = value;
-            value += stride;
+            const float idx = static_cast<float>(n - 1 - i);
+            const T generated = value + stride * static_cast<T>(idx);
+            *first++ = generated;
         }
     }
     virtual void generate(int idx,
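To make the new behaviour concrete, here is a condensed, self-contained copy of the rewritten strided_iota with a small driver (example values assumed):

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Copy of the rewritten helper: fills [first, first+n) in descending order,
// starting at value + stride*(n-1) and ending at value.
template <typename IT, typename T>
static void strided_iota(IT first, size_t n, T value, T stride) {
    for (size_t i = 0; i < n; i++) {
        const float idx = static_cast<float>(n - 1 - i);
        *first++ = value + stride * static_cast<T>(idx);
    }
}

int main() {
    std::vector<float> buf(5);
    strided_iota(buf.begin(), buf.size(), 1.0f, 0.5f);
    for (float v : buf)
        std::printf("%.1f ", v);  // prints: 3.0 2.5 2.0 1.5 1.0
    return 0;
}
```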
@@ -444,11 +458,39 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
         create_input(params[param_idx++], targetInputStaticShapes[0], idx + 2.0f);  // k
         create_input(params[param_idx++], targetInputStaticShapes[0], idx + 3.0f);  // v

-        // atten_mask - create appropriate shape based on sequence length
-        auto seq_len = targetInputStaticShapes[0][0];
-        create_input(params[param_idx++],
-                     {1, targetInputStaticShapes[0][2], seq_len, seq_len},
-                     0.0f);  // atten_mask
+        // atten_mask - always use rectangular mask [1, head_num, q_len, total_kv_len]
+        auto mask_param = params[param_idx++];
+        const size_t head_num = targetInputStaticShapes[0][2];
+        const size_t q_len = targetInputStaticShapes[0][0];
+        const size_t total_kv_len = static_cast<size_t>(past_len_count) + q_len;
+
+        if (sliding_window > 0) {
+            // Sliding window: rectangular mask with sliding window logic
+            ov::Tensor mask_tensor(ov::element::f32, {1, head_num, q_len, total_kv_len});
+            auto* mask_data = mask_tensor.data<float>();
+            const float neg_inf = -std::numeric_limits<float>::infinity();
+            const int32_t offset = -sliding_window;
+
+            for (size_t h = 0; h < head_num; ++h) {
+                for (size_t q_pos = 0; q_pos < q_len; ++q_pos) {
+                    const int32_t global_q_idx = past_len_count + static_cast<int32_t>(q_pos);
+                    for (size_t kv_idx = 0; kv_idx < total_kv_len; ++kv_idx) {
+                        const int32_t global_k_idx = static_cast<int32_t>(kv_idx);
+                        const bool within_window = global_k_idx > global_q_idx + offset;
+                        const bool causal = global_k_idx <= global_q_idx;
+                        const bool allow = within_window && causal;
+                        const size_t linear_idx = (h * q_len + q_pos) * total_kv_len + kv_idx;
+                        mask_data[linear_idx] = allow ? 0.f : neg_inf;
+                    }
+                }
+            }
+            inputs[mask_param] = mask_tensor;
+            past_len_count += static_cast<int32_t>(q_len);
+        } else {
+            // Normal case: rectangular mask with all zeros (no masking needed, causal handled by SDPA)
+            create_input(mask_param, {1, head_num, q_len, total_kv_len}, 0.0f);
+            past_len_count += static_cast<int32_t>(q_len);
+        }

         // scale - single value for scaling
         create_input(params[param_idx++], {1}, 1.0f / std::sqrt(64));  // scale
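The predicate above reduces to: key k is visible from query q iff q - sliding_window < k <= q, with both indices global (past tokens included). A standalone sketch, using a hypothetical visible() helper, that prints the resulting banded lower-triangular pattern:

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical helper mirroring the test's predicate: key k is visible from
// query q iff it is causal (k <= q) and inside the window (k > q - window).
static bool visible(int32_t q, int32_t k, int32_t sliding_window) {
    const int32_t offset = -sliding_window;     // same convention as the test
    const bool within_window = k > q + offset;  // i.e. k > q - sliding_window
    const bool causal = k <= q;
    return within_window && causal;
}

int main() {
    // For sliding_window = 4, query position 6 sees keys 3..6 only; the full
    // pattern is a lower triangle banded to at most 4 entries per row.
    for (int32_t q = 0; q < 8; ++q) {
        for (int32_t k = 0; k < 8; ++k)
            std::putchar(visible(q, k, 4) ? 'o' : '.');
        std::putchar('\n');
    }
    return 0;
}
```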
@@ -458,8 +500,13 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
             create_input(params[param_idx++], {1, targetInputStaticShapes[0][2], 1, 1}, 0.1f);  // sink
         }

-        // past_kv
-        create_input(params[param_idx++], targetInputStaticShapes[1], idx + 4.0f);  // past_kv
+        // past_kv - For SDPA with ReadValue/Assign:
+        // - Iteration 0: empty tensor [0,1,8,64], ReadValue uses this as initial value (empty)
+        // - Iteration 1+: should be empty [0,1,8,64], ReadValue ignores this and uses Variable state
+        // Note: We always pass empty tensor since ReadValue/Assign manages the actual KV cache
+        auto past_kv_shape = targetInputStaticShapes[1];
+        past_kv_shape[0] = 0;  // Always use empty past for ReadValue-based SDPA
+        create_input(params[param_idx++], past_kv_shape, 0.0f);  // past_kv (empty)

         // beam_idx - shape matching batch dimension
         create_input(params[param_idx++], ov::Shape{targetInputStaticShapes[0][1]},
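Passing an always-empty past_kv works because of ReadValue/Assign semantics: the input only seeds the variable before its state is first assigned; afterwards the stored state wins. A toy model of that behaviour in plain C++ (not the OpenVINO API):

```cpp
#include <cstdio>
#include <utility>
#include <vector>

struct Variable {
    bool initialized = false;
    std::vector<float> state;
};

// ReadValue: the input only matters while the variable has no state yet.
static std::vector<float> read_value(Variable& var, const std::vector<float>& input) {
    return var.initialized ? var.state : input;
}

// Assign: stores the concatenated KV back into the variable.
static void assign(Variable& var, std::vector<float> value) {
    var.state = std::move(value);
    var.initialized = true;
}

int main() {
    Variable var_k;
    for (int step = 0; step < 3; ++step) {
        std::vector<float> past = read_value(var_k, {});  // the test always feeds an empty tensor
        const std::vector<float> fresh = {static_cast<float>(step)};  // this step's K tokens
        past.insert(past.end(), fresh.begin(), fresh.end());
        assign(var_k, std::move(past));
        std::printf("step %d: cache length %zu\n", step, var_k.state.size());
    }
    return 0;  // cache grows 1, 2, 3 even though the past_kv input stays empty
}
```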
@@ -483,6 +530,7 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
     ov::Tensor key_cache;
     ov::Tensor value_cache;
     int32_t past_len_count = 0;
+    int32_t sliding_window = 0;
 };

 class PagedAttnVSSDPATest : public PagedAttnTestBase {
@@ -530,7 +578,7 @@ class PagedAttnVSSDPATest : public PagedAttnTestBase {
         std::vector<ov::Tensor> outputs;
         int idx = 0;
         for (auto&& shapes : targetStaticShapes) {
-            generate(idx++, false, shapes, false, sinkInput);  // Use the same sink input setting as the test model
+            generate(idx++, false, shapes, false, sinkInput);
             for (const auto& input : inputs) {
                 inferRequest.set_tensor(input.first, input.second);
             }
@@ -540,25 +588,32 @@ class PagedAttnVSSDPATest : public PagedAttnTestBase {
             outputTensor.copy_to(copy);
             outputs.push_back(copy);
         }
+        reset();
         return outputs;
     }
 };

 TEST_P(PagedAttnVSSDPATest, CompareWithRefs) {
     SKIP_IF_CURRENT_TEST_IS_DISABLED();
-    const auto& [inType, inputShapes, extendBlockIndices, sinkInput, additional_config] = this->GetParam();
+    const auto& [inType, inputShapes, extendBlockIndices, sinkInput, slidingWindow, additional_config] =
+        this->GetParam();
     const bool isSageAttn =
         intel_cpu::contains_key_value(additional_config, {ov::intel_cpu::enable_sage_attn.name(), true});
     if (inType == ElementType::bf16 && !ov::with_cpu_x86_bfloat16())
         GTEST_SKIP();
     if (isSageAttn && !(ov::with_cpu_x86_avx512_core_amx_int8() || CPUTestUtils::with_cpu_x86_avx2_vnni_2()))
         GTEST_SKIP();
+
+    past_len_count = 0;
+
     // compare the logits from paged attn and sdpa
     auto actualOutputs = run_test(function, extendBlockIndices, sinkInput);
     // reference model doesn't support sage attention
     if (isSageAttn) {
         configuration[ov::intel_cpu::enable_sage_attn.name()] = false;
     }
+    // Reset past_len_count before running reference test to ensure consistent mask generation
+    past_len_count = 0;
     auto expectedOutputs = run_ref_test(functionRefs, sinkInput);
     for (size_t i = 0; i < actualOutputs.size(); i++) {
         ov::test::utils::compare(expectedOutputs[i], actualOutputs[i], abs_threshold, rel_threshold);
@@ -582,8 +637,21 @@ INSTANTIATE_TEST_SUITE_P(smoke_PagedAttnVSSDPATest,
                                             ::testing::ValuesIn(inputShapeAndReorders),
                                             ::testing::Values(true, false),
                                             ::testing::Values(true, false),
+                                            ::testing::Values(0),  // sliding_window = 0
                                             ::testing::ValuesIn(additional_configs)),
                          PagedAttnTestBase::getTestCaseName);
+
+// Sliding window test with same shapes as normal test
+INSTANTIATE_TEST_SUITE_P(smoke_PagedAttnVSSDPATest_WithSlidingWindow,
+                         PagedAttnVSSDPATest,
+                         ::testing::Combine(::testing::Values(ElementType::f32),
+                                            ::testing::ValuesIn(inputShapeAndReorders),
+                                            ::testing::Values(false),  // extendBlockIndices
+                                            ::testing::Values(true),   // sinkInput
+                                            ::testing::Values(8),      // sliding_window = 8
+                                            ::testing::Values(ov::AnyMap{
+                                                {ov::intel_cpu::enable_sage_attn.name(), false}})),
+                         PagedAttnTestBase::getTestCaseName);
 } // namespace

 class PagedAttnVSMatmulTest : public PagedAttnTestBase {
@@ -786,7 +854,8 @@ class PagedAttnVSMatmulTest : public PagedAttnTestBase {

 TEST_P(PagedAttnVSMatmulTest, CompareWithRefs) {
     SKIP_IF_CURRENT_TEST_IS_DISABLED();
-    const auto& [inType, inputShapes, extendBlockIndices, sinkInput, additional_config] = this->GetParam();
+    const auto& [inType, inputShapes, extendBlockIndices, sinkInput, slidingWindow, additional_config] =
+        this->GetParam();
     const bool isSageAttn =
         intel_cpu::contains_key_value(additional_config, {ov::intel_cpu::enable_sage_attn.name(), true});
     if (inType == ElementType::bf16 && !ov::with_cpu_x86_bfloat16())
@@ -821,6 +890,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_PagedAttnVSMatmulTest,
                                             ::testing::ValuesIn(inputShapes),
                                             ::testing::Values(true, false),
                                             ::testing::Values(false),
+                                            ::testing::Values(0),  // sliding_window = 0
                                             ::testing::ValuesIn(additional_configs)),
                          PagedAttnTestBase::getTestCaseName);
 } // namespace
