Commit 0b85b40

keep PA sink logic focused on a constant with shape {1, H, 1, 1} to align with the transform process
1 parent d9cd8d3 commit 0b85b40

File tree: 3 files changed (+46, -32 lines)

src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp

Lines changed: 3 additions & 5 deletions
@@ -926,7 +926,7 @@ struct MHAHelper {
         // sink processing is independent of sliding_window size
         float* sink = nullptr;
         if (sinks) {
-            sink = &sinks.at<float>({batch_in_seq, h, m, 0}, true);
+            sink = &sinks.at<float>({0, h, 0, 0}, true);
         }
         if (_sliding_window) {
             size_t start_idx = 0;
@@ -1227,7 +1227,6 @@ struct MHAHelper {
                   size_t cur_kv_len,
                   const PlainTensor& alibi_slopes,
                   float* score_output,
-                  size_t batch_in_seq,
                   const PlainTensor& sinks) {
 #    if defined(OPENVINO_ARCH_X86_64)
         if (any_of(_fastpath_valid_prec, ov::element::bf16, ov::element::f16)) {
@@ -1288,7 +1287,7 @@ struct MHAHelper {
             }
             float* sink = nullptr;
             if (sinks) {
-                sink = &sinks.at<float>({batch_in_seq, h, pq, 0}, true);
+                sink = &sinks.at<float>({0, h, 0, 0}, true);
             }
             if (_sliding_window) {
                 size_t start_idx = 0;
@@ -1498,7 +1497,7 @@ struct MHAHelper {
             }
             float* sink = nullptr;
             if (sinks) {
-                sink = &sinks.at<float>({b, h, pq, 0}, true);
+                sink = &sinks.at<float>({0, h, 0, 0}, true);
             }
             if (_sliding_window) {
                 size_t start_idx = 0;
@@ -1808,7 +1807,6 @@ struct MHA {
                              cur_kv_len,
                              alibi_slopes,
                              score_output,
-                             batch_in_seq,
                              sinks);
         } else {
             const auto batch_in_reorder = item.batch_in_reorder;
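
The kernel-side change above is uniform: the sink tensor handed to the kernel is a per-head constant of shape {1, H, 1, 1}, so the batch and query indices are always zero and only the head index varies, which is why batch_in_seq can be dropped from the call chain. A minimal standalone sketch of that lookup, assuming a plain contiguous float buffer rather than the actual PlainTensor type:

#include <cstddef>
#include <vector>

// Illustrative sketch only (not the PlainTensor API): for a sink constant of
// shape {1, H, 1, 1} stored contiguously, the element at logical index
// {0, h, 0, 0} is simply element h of the buffer, shared by every batch
// entry and every query position.
inline float sink_for_head(const std::vector<float>& sink_data, std::size_t h) {
    return sink_data.at(h);
}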

src/plugins/intel_cpu/src/nodes/paged_attn.cpp

Lines changed: 16 additions & 1 deletion
@@ -25,8 +25,10 @@
 #include "openvino/core/except.hpp"
 #include "openvino/core/node.hpp"
 #include "openvino/core/type/element_type.hpp"
+#include "openvino/op/constant.hpp"
 #include "openvino/runtime/system_conf.hpp"
 #include "shape_inference/shape_inference_internal_dyn.hpp"
+#include "transformations/utils/utils.hpp"
 #include "utils/general_utils.h"

 #if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64) || defined(OPENVINO_ARCH_ARM64)
@@ -302,7 +304,20 @@ bool PagedAttention::isSupportedOperation(const std::shared_ptr<const ov::Node>&
         }
         auto orgInput = static_cast<int>(op->get_input_size());
         if (op->get_type_name() == std::string("PagedAttentionExtension") &&
-            orgInput == PagedAttentionExecutor::ID_SLIDING_WINDOW + 1) {
+            orgInput == PagedAttentionExecutor::ID_SINKS + 1) {
+            if (!ov::op::util::is_on_path<ov::op::v0::Constant>(op->input_value(PagedAttentionExecutor::ID_SINKS))) {
+                errorMessage = "Only Constant operation on sink input is supported";
+                return false;
+            }
+#if defined(OPENVINO_ARCH_ARM64)
+            // ARM platform doesn't support non-empty sink input yet
+            // Check if sink input is non-empty (shape size > 0)
+            const auto& sink_shape = op->get_input_partial_shape(PagedAttentionExecutor::ID_SINKS);
+            if (sink_shape.is_static() && ov::shape_size(sink_shape.to_shape()) > 0) {
+                errorMessage = "PagedAttentionExtension with non-empty sink input is not supported on ARM platform";
+                return false;
+            }
+#endif
             return true;
         }
     } catch (...) {
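
The ARM guard above keys off the sink input's shape: the SDPA->PA transformation emits an empty constant of shape {0} when the model has no sink, and a {1, H, 1, 1} constant otherwise. A minimal sketch of that shape test, assuming only the public ov::PartialShape and ov::shape_size APIs (the Constant-on-path check in the diff relies on a transformation utility that is not reproduced here):

#include <openvino/core/partial_shape.hpp>
#include <openvino/core/shape.hpp>

// Illustrative helper (hypothetical name): true when the sink input carries
// real per-head values. A {1, H, 1, 1} sink has H elements, while the
// "no sink" placeholder of shape {0} has zero elements.
inline bool has_nonempty_sink(const ov::PartialShape& sink_shape) {
    return sink_shape.is_static() && ov::shape_size(sink_shape.to_shape()) > 0;
}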

src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/paged_attn.cpp

Lines changed: 27 additions & 26 deletions
@@ -138,20 +138,23 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
             std::make_shared<ov::op::v0::Constant>(ov::element::i32, Shape{}, std::vector<int32_t>{64});
         auto xattention_stride =
             std::make_shared<ov::op::v0::Constant>(ov::element::i32, Shape{}, std::vector<int32_t>{8});
-        // Create sink_input parameter for testing - shape [1, num_heads, 1, 1] as per PagedAttentionExecutor::ID_SINKS
-        // PagedAttentionExtension always expects 21 inputs, so we must always include sinks parameter
-        auto sinks = make_param(PartialShape{1, head_num, 1, 1}, data_type, "sinks");
-
-        ParameterVector params = {q,
-                                  k,
-                                  v,
-                                  key_cache,
-                                  value_cache,
-                                  past_lens,
-                                  subsequence_begins,
-                                  block_indices,
-                                  block_indices_begins,
-                                  sinks};
+        // Create sink input as Constant (not Parameter) for testing
+        // PagedAttentionExtension requires sink input to be Constant
+        // Use shape [1, head_num, 1, 1] when use_sink_input=true, or empty shape [0] when false
+        std::shared_ptr<ov::op::v0::Constant> sinks;
+        if (use_sink_input) {
+            // Create real sink tokens for testing sink functionality
+            std::vector<float> sink_data(static_cast<size_t>(head_num), 0.1f);
+            sinks = std::make_shared<ov::op::v0::Constant>(data_type,
+                                                           Shape{1, static_cast<size_t>(head_num), 1, 1},
+                                                           sink_data);
+        } else {
+            // Create empty sink (matching SDPA->PA transformation behavior when no sink)
+            sinks = std::make_shared<ov::op::v0::Constant>(data_type, Shape{0}, std::vector<float>{});
+        }
+
+        ParameterVector params =
+            {q, k, v, key_cache, value_cache, past_lens, subsequence_begins, block_indices, block_indices_begins};
         OutputVector paged_attn_inputs = {q,
                                           k,
                                           v,
@@ -461,10 +464,7 @@ class PagedAttnTestBase : public testing::WithParamInterface<PagedAttnTestParams
         inputs.insert({function->get_parameters()[7], block_indices});
         inputs.insert({function->get_parameters()[8], block_indices_begins});

-        // Create sink_input data - shape [1, num_heads, 1, 1] as per PagedAttentionExtensor specification
-        // Always create the sink input data since PagedAttentionExtension expects 21 inputs
-        // The value will be ignored when use_sink_input=false
-        create_input(function->get_parameters()[9], {1, qkv_shape[2], 1, 1}, use_sink_input ? 0.1f : 0.0f);
+        // Note: sink is a Constant in the model, not a Parameter, so no need to provide input for it

         past_len_count += static_cast<int32_t>(qkv_shape[0]);

@@ -655,21 +655,22 @@ INSTANTIATE_TEST_SUITE_P(smoke_PagedAttnVSSDPATest,
                          ::testing::Combine(::testing::Values(ElementType::f32, ElementType::bf16),
                                             ::testing::ValuesIn(inputShapeAndReorders),
                                             ::testing::Values(true, false),
+                                            // TODO: Xattn should not directly compare with SDPA/decomposed Matmul,
+                                            // which does not contain sparse logic
                                             ::testing::Values(true, false),
-                                            ::testing::Values(true, false),
-                                            ::testing::Values(0),  // sliding_window = 0
+                                            ::testing::Values(false),
+                                            ::testing::Values(0),
                                             ::testing::ValuesIn(additional_configs)),
                          PagedAttnTestBase::getTestCaseName);

-// Sliding window test with same shapes as normal test
-INSTANTIATE_TEST_SUITE_P(smoke_PagedAttnVSSDPATest_WithSlidingWindow,
+INSTANTIATE_TEST_SUITE_P(smoke_PagedAttnVSSDPATest_WithSlidingWindowAndSinks,
                          PagedAttnVSSDPATest,
                          ::testing::Combine(::testing::Values(ElementType::f32),
                                             ::testing::ValuesIn(inputShapeAndReorders),
-                                            ::testing::Values(false),  // extendBlockIndices
-                                            ::testing::Values(false),  // enableXattn
-                                            ::testing::Values(true),   // sinkInput
-                                            ::testing::Values(8),      // sliding_window = 8
+                                            ::testing::Values(false),        // extendBlockIndices
+                                            ::testing::Values(false),        // enableXattn
+                                            ::testing::Values(true, false),  // sinkInput
+                                            ::testing::Values(0, 8),         // sliding_window = 0 or 8
                                             ::testing::Values(ov::AnyMap{
                                                 {ov::intel_cpu::enable_sage_attn.name(), false}})),
                          PagedAttnTestBase::getTestCaseName);
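
A brief, hedged usage note on the test change: because the sink is now baked into the model as an ov::op::v0::Constant rather than fed as a Parameter, its per-head values can be read straight off the constant when needed, for example for debugging or a reference computation; the helper name below is illustrative, not part of the commit:

#include <memory>
#include <vector>
#include "openvino/op/constant.hpp"

// Illustrative only: pull the H per-head sink values out of a {1, H, 1, 1}
// f32 constant like the one the test builds; returns an empty vector when
// the sink constant has the empty shape {0}.
std::vector<float> read_sink_values(const std::shared_ptr<ov::op::v0::Constant>& sinks) {
    return sinks->cast_vector<float>();
}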
