
Commit 0a05d7c (parent: d80eede)

Support setting OV CPU xattention parameters via the GenAI API
File tree

1 file changed (+23, -18 lines)

src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp

Lines changed: 23 additions & 18 deletions
@@ -2181,40 +2181,39 @@ struct AttentionExecutor : public PagedAttentionExecutor {
         // TODO: enable block_size to be multiple of 32
         OPENVINO_ASSERT(block_size == 32, "CPU: block size must be 32, current: ", block_size);
 
-        size_t xt_stride = 16;
+        xattention_threshold.resize<float>({1});
+        xattention_threshold.ptr<float>()[0] = 0.6f;
+        xattention_stride = 16;
         // The original block_size of the sparse attention mask;
-        size_t xt_block_size = 128;
-        // auto xt_block_size = 32;
-        float xt_threshold = 0.6f;
-        // float xt_threshold = 1.0f;
+        xattention_block_size = 128;
 
         // If to support second token sparse attention, need generate sparse mask after concat_pastkv
-        if (q.size(0) > 1) {
+        if (xattention_threshold && q.size(0) > 1) {
             sparse_attention_mask = get_sparse_blocks(q,
                                                       k,
                                                       past_lens,
                                                       subsequence_begins,
                                                       block_indices,
                                                       block_indices_begins,
-                                                      xt_stride,
-                                                      xt_block_size,
-                                                      xt_threshold);
+                                                      xattention_stride,
+                                                      xattention_block_size,
+                                                      xattention_threshold);
 
             // Only support block_size <= sparse_attention_BlockSize and sparse_attention_BlockSize must be an integer
             // multiple
-            if (block_size != xt_block_size) {
-                if (block_size > xt_block_size) {
-                    OPENVINO_THROW("not supported: block_size > xt_block_size");
+            if (block_size != xattention_block_size) {
+                if (block_size > xattention_block_size) {
+                    OPENVINO_THROW("not supported: block_size > xattention_block_size");
                 }
-                if (xt_block_size % block_size != 0) {
-                    OPENVINO_THROW("not supported: xt_block_size ",
-                                   xt_block_size,
+                if (xattention_block_size % block_size != 0) {
+                    OPENVINO_THROW("not supported: xattention_block_size ",
+                                   xattention_block_size,
                                    " is not an integer multiple of block_size ",
                                    block_size);
                 }
             }
             // keep original mask granularity; remember its block size for on-the-fly mapping
-            _helper._sparse_mask_block_size = xt_block_size;
+            _helper._sparse_mask_block_size = xattention_block_size;
         }
 
         _helper.init(H,
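
For context on the guard above: the sparse mask is generated at xattention_block_size granularity (128 by default), while the KV cache is paged in block_size units (asserted to be 32), so every mask block must cover a whole number of cache blocks for the on-the-fly mapping that _sparse_mask_block_size enables. A minimal standalone sketch of that mapping rule, under the same invariant; mask_block_for is an illustrative name, not a function from this file:

    #include <cstddef>
    #include <stdexcept>

    // Map a KV-cache block index to the sparse-mask block that covers it.
    // Requires mask_block_size to be an integer multiple of cache_block_size,
    // the same invariant the diff enforces with OPENVINO_THROW.
    inline size_t mask_block_for(size_t cache_block, size_t cache_block_size, size_t mask_block_size) {
        if (mask_block_size % cache_block_size != 0) {
            throw std::invalid_argument("mask_block_size must be a multiple of cache_block_size");
        }
        const size_t ratio = mask_block_size / cache_block_size;  // e.g. 128 / 32 = 4
        return cache_block / ratio;  // cache blocks 0..3 -> mask block 0, 4..7 -> mask block 1, ...
    }

With the defaults above, one mask block spans four cache blocks, which is why block_size > xattention_block_size is rejected outright.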
@@ -2290,14 +2289,20 @@ struct AttentionExecutor : public PagedAttentionExecutor {
                                                 PlainTensor& block_indices_begins,
                                                 size_t x_attention_stride,
                                                 size_t x_attention_block_size,
-                                                float threshold) {
+                                                PlainTensor& threshold) {
         size_t num_seqs = past_lens.size(0);
         std::vector<PlainTensor> masks(num_seqs);
 
         // TODO: support multiple batches
         for (size_t seq_idx = 0; seq_idx < 1; seq_idx++) {
             if (q.size(0) > 1) {
-                masks[seq_idx] = xattn_estimate(q, k, x_attention_block_size, x_attention_stride, 1, threshold, true);
+                masks[seq_idx] = xattn_estimate(q,
+                                                k,
+                                                x_attention_block_size,
+                                                x_attention_stride,
+                                                1,
+                                                threshold.ptr<float>()[seq_idx],
+                                                true);
             }
         }
         return masks;
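
Since threshold is now a PlainTensor rather than a scalar, each subsequence can in principle carry its own value: the loop reads threshold.ptr<float>()[seq_idx], even though only seq_idx == 0 is processed for now (see the TODO). A hypothetical caller-side sketch of sizing the tensor per sequence, reusing the resize<float>/ptr<float> calls that appear in the diff; the per-sequence loop is an assumption about how multi-batch support might look, not code from this commit:

    // Hypothetical setup: one xattention threshold per subsequence.
    // PlainTensor is the CPU plugin's tensor helper used throughout this file;
    // values beyond index 0 are not consumed yet (see the TODO in the loop).
    PlainTensor xattention_threshold;
    const size_t num_seqs = past_lens.size(0);
    xattention_threshold.resize<float>({num_seqs});
    for (size_t seq_idx = 0; seq_idx < num_seqs; seq_idx++) {
        xattention_threshold.ptr<float>()[seq_idx] = 0.6f;  // default used by the commit
    }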
