
Commit 9eb509b (parent 8f8b243)

Integrate sparse attention: include Stage 1 (sparse mask generation) and Stage 2 (sparse attention computation).

1 file changed: +57 −46 lines

src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp
(57 additions, 46 deletions)
```diff
@@ -1743,7 +1743,7 @@ struct MHA {
                 score_output = _helper._score_output.template ptr<float>() + score_offset * _helper.H;
             }
         }
-
+        // TODO: support second token sparse attention execution
         _helper.exec_kernel_one_bh(
             q.slice(0, batch_in_token, batch_in_token),
             k_cache,
@@ -1758,8 +1758,7 @@ struct MHA {
             cur_kv_len,
             alibi_slopes,
             score_output,
-            batch_in_seq,
-            sparse_attention_mask);
+            batch_in_seq);
     } else {
         const auto batch_in_reorder = item.batch_in_reorder;
         const auto q_blk = item.q_block_id;
@@ -1918,6 +1917,7 @@ struct MHA {
                 score_aggregation_window,
                 sparse_attention_mask);
         } else {
+            // TODO: support second token sparse attention execution
             _helper.exec_loop_bhl(query,
                                   present_key,
                                   present_value,
@@ -1929,8 +1929,7 @@ struct MHA {
                                   block_indices,
                                   block_indices_begins,
                                   alibi_slopes,
-                                  score_aggregation_window,
-                                  sparse_attention_mask);
+                                  score_aggregation_window);
         }
     }
 };
```
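(Note: the four hunks above drop the sparse_attention_mask argument from the second-token paths — exec_kernel_one_bh and exec_loop_bhl — leaving TODO markers in its place, so for now sparse attention is exercised only on the branches that still receive the mask.)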
```diff
@@ -2159,40 +2158,42 @@ struct AttentionExecutor : public PagedAttentionExecutor {
         // TODO: enable block_size to be multiple of 32
         OPENVINO_ASSERT(block_size == 32, "CPU: block size must be 32, current: ", block_size);
 
+        // TODO: use the inputs values
+        // PlainTensor sum;
+        // PlainTensor mask;
+        size_t xt_stride = 16;
+        // The original block_size of the sparse attention mask;
+        size_t xt_block_size = 128;
+        // auto xt_block_size = 32;
+        float xt_threshold = 0.9f;
+
         // --- Create and initialize sparse_attention_mask ---
         sparse_attention_mask.clear();
-        size_t k_blocks = div_up(max_context_len, block_size);
+        // TODO: maybe use real context_len to save memory usage
+        size_t k_blocks = div_up(max_context_len, xt_block_size);
         for (size_t b = 0; b < B_seq; ++b) {
             auto q_len_for_batch = subsequence_begins.ptr<int32_t>()[b + 1] - subsequence_begins.ptr<int32_t>()[b];
-            size_t q_blocks = div_up(static_cast<size_t>(q_len_for_batch), block_size);
+            size_t q_blocks = div_up(static_cast<size_t>(q_len_for_batch), xt_block_size);
             PlainTensor mask;
             mask.resize<bool>({H, q_blocks, k_blocks});
             // Default initialize all to false
             std::memset(mask.ptr<bool>(), 0, H * q_blocks * k_blocks * sizeof(bool));
-            // Optional: activate all (example)
-            for (size_t h = 0; h < H; ++h) {
-                for (size_t q_blk = 0; q_blk < q_blocks; ++q_blk) {
-                    for (size_t k_blk = 0; k_blk < k_blocks; ++k_blk) {
-                        // At the symmetric point (q_blocks/2, k_blocks/2), set the upper-left and lower-right blocks
-                        // to true, others to false bool left_top = (q_blk < q_blocks / 2) && (k_blk < k_blocks / 2);
-                        // bool right_bottom = (q_blk >= (q_blocks + 1) / 2) && (k_blk >= (k_blocks + 1) / 2);
-                        // mask.ptr<bool>(h, q_blk, k_blk)[0] = left_top || right_bottom;
-
-                        // All masks are set to true
-                        mask.ptr<bool>(h, q_blk, k_blk)[0] = true;
-
-                        // Set the middle block to false
-                        // if (q_blk == q_blocks / 2 && k_blk == k_blocks / 2) {
-                        //     mask.ptr<bool>(h, q_blk, k_blk)[0] = false;
-                        // } else {
-                        //     mask.ptr<bool>(h, q_blk, k_blk)[0] = true;
-                        // }
-                    }
-                }
-            }
             sparse_attention_mask.push_back(std::move(mask));
         }
 
+        // TODO: Avoid temporary vector and assignment here. Consider changing get_sparse_blocks
+        // to take `std::vector<PlainTensor>& sparse_attention_mask` and fill it in-place
+        // to reduce PlainTensor copies (memcpy of metadata) and allocations.
+        sparse_attention_mask = get_sparse_blocks(q,
+                                                  k,
+                                                  past_lens,
+                                                  subsequence_begins,
+                                                  block_indices,
+                                                  block_indices_begins,
+                                                  xt_stride,
+                                                  xt_block_size,
+                                                  xt_threshold);
+
         // --- Broadcast sparse_attention_mask to support different block sizes ---
         // The input mask may be [h, q_blocks_orig, k_blocks_orig], and needs to be broadcast to [h, q_blocks, k_blocks]
         auto broadcast_sparse_attention_mask =
```
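For orientation, here is a minimal standalone sketch of the per-sequence mask layout the hunk above builds: one [H, q_blocks, k_blocks] boolean grid per sequence at xt_block_size granularity, zero-initialized before get_sparse_blocks fills it. PlainTensor is replaced by a flat std::vector<char>, and the BlockMask name is invented for the example.

```cpp
#include <cstddef>
#include <vector>

// Standalone sketch (not the OpenVINO code): per-sequence block-mask layout
// at xt_block_size granularity. PlainTensor is stood in for by a flat vector.
struct BlockMask {
    size_t H, q_blocks, k_blocks;
    std::vector<char> data;  // row-major [H, q_blocks, k_blocks]; 0 = skip block, 1 = attend

    BlockMask(size_t heads, size_t q_len, size_t max_context_len, size_t xt_block_size)
        : H(heads),
          q_blocks((q_len + xt_block_size - 1) / xt_block_size),            // div_up
          k_blocks((max_context_len + xt_block_size - 1) / xt_block_size),  // div_up
          data(H * q_blocks * k_blocks, 0) {}  // default-initialized to all-false, as in the diff

    char& at(size_t h, size_t q_blk, size_t k_blk) {
        return data[(h * q_blocks + q_blk) * k_blocks + k_blk];
    }
};

// Usage: with H = 32, q_len = 1024, max_context_len = 4096, xt_block_size = 128,
// each head gets an 8 x 32 grid of block decisions.
```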
```diff
@@ -2233,20 +2234,19 @@ struct AttentionExecutor : public PagedAttentionExecutor {
         };
         // The original block_size of the sparse attention mask; can be specified later via the Page Attention Node
         // parameter const size_t sparse_attention_BlockSize = 128;
-        const size_t sparse_attention_BlockSize = 32;
         // Only support block_size <= sparse_attention_BlockSize and sparse_attention_BlockSize must be an integer
         // multiple
-        if (block_size != sparse_attention_BlockSize) {
-            if (block_size > sparse_attention_BlockSize) {
-                OPENVINO_THROW("not supported: block_size > sparse_attention_BlockSize");
+        if (block_size != xt_block_size) {
+            if (block_size > xt_block_size) {
+                OPENVINO_THROW("not supported: block_size > xt_block_size");
             }
-            if (sparse_attention_BlockSize % block_size != 0) {
-                OPENVINO_THROW("not supported: sparse_attention_BlockSize ",
-                               sparse_attention_BlockSize,
+            if (xt_block_size % block_size != 0) {
+                OPENVINO_THROW("not supported: xt_block_size ",
+                               xt_block_size,
                                " is not an integer multiple of block_size ",
                                block_size);
             }
-            broadcast_sparse_attention_mask(sparse_attention_mask, sparse_attention_BlockSize, block_size);
+            broadcast_sparse_attention_mask(sparse_attention_mask, xt_block_size, block_size);
         }
 
         _helper.init(H,
```
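The broadcast_sparse_attention_mask lambda itself lies outside this diff. Based on the comments above (a coarse [h, q_blocks_orig, k_blocks_orig] mask expanded to [h, q_blocks, k_blocks]), a plausible reading is sketched below with flat vectors; the function name and layout are assumptions, not the file's actual implementation.

```cpp
#include <cstddef>
#include <vector>

// Sketch of the broadcast step, assuming a coarse [H, Qc, Kc] mask at
// xt_block_size granularity is expanded to a fine [H, Qc*r, Kc*r] mask at
// block_size granularity, where r = xt_block_size / block_size (an exact
// integer, per the check in the hunk above).
std::vector<char> broadcast_block_mask(const std::vector<char>& coarse,
                                       size_t H, size_t Qc, size_t Kc, size_t r) {
    const size_t Qf = Qc * r;
    const size_t Kf = Kc * r;
    std::vector<char> fine(H * Qf * Kf);
    for (size_t h = 0; h < H; ++h) {
        for (size_t qf = 0; qf < Qf; ++qf) {
            for (size_t kf = 0; kf < Kf; ++kf) {
                // Every fine block inherits the decision of the coarse block covering it.
                fine[(h * Qf + qf) * Kf + kf] = coarse[(h * Qc + qf / r) * Kc + kf / r];
            }
        }
    }
    return fine;
}
```

With xt_block_size = 128 and block_size = 32 as set in the earlier hunk, r = 4, so each coarse decision fans out to a 4×4 tile of fine blocks.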
```diff
@@ -2314,6 +2314,27 @@ struct AttentionExecutor : public PagedAttentionExecutor {
         }
     }
 
+    std::vector<PlainTensor> get_sparse_blocks(PlainTensor& q,
+                                               PlainTensor& k,
+                                               PlainTensor& past_lens,
+                                               PlainTensor& subsequence_begins,
+                                               PlainTensor& block_indices,
+                                               PlainTensor& block_indices_begins,
+                                               size_t x_attention_stride,
+                                               size_t x_attention_block_size,
+                                               float threshold) {
+        size_t num_seqs = past_lens.size(0);
+        std::vector<PlainTensor> masks(num_seqs);
+
+        // TODO: support multiple batches
+        for (size_t seq_idx = 0; seq_idx < 1; seq_idx++) {
+            if (q.size(0) > 1) {
+                masks[seq_idx] = xattn_estimate(q, k, x_attention_block_size, x_attention_stride, 1, threshold, true);
+            }
+        }
+        return masks;
+    }
+
     void execute(const std::vector<MemoryPtr>& inputs, const std::vector<MemoryPtr> outputs) override {
         PlainTensor q;
         PlainTensor k;
```
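The TODO two hunks up proposes filling the caller's vector in place rather than returning a temporary. One possible shape of that change is sketched below; the signature is hypothetical, PlainTensor and xattn_estimate are the file's own types/functions, and the parameters the current body never reads (subsequence_begins, block_indices, block_indices_begins) are omitted.

```cpp
// Hypothetical in-place variant of get_sparse_blocks, per the TODO in the diff:
// the caller's vector is cleared and filled directly, avoiding the temporary
// vector and the PlainTensor metadata copies of the return-and-assign pattern.
void get_sparse_blocks_inplace(PlainTensor& q,
                               PlainTensor& k,
                               PlainTensor& past_lens,
                               size_t x_attention_stride,
                               size_t x_attention_block_size,
                               float threshold,
                               std::vector<PlainTensor>& sparse_attention_mask) {
    const size_t num_seqs = past_lens.size(0);
    sparse_attention_mask.clear();
    sparse_attention_mask.resize(num_seqs);
    // TODO (carried over from the diff): extend beyond the first sequence
    for (size_t seq_idx = 0; seq_idx < 1; seq_idx++) {
        if (q.size(0) > 1) {  // first-token (prefill) path only
            sparse_attention_mask[seq_idx] =
                xattn_estimate(q, k, x_attention_block_size, x_attention_stride, 1, threshold, true);
        }
    }
}
```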
```diff
@@ -2371,16 +2392,6 @@ struct AttentionExecutor : public PagedAttentionExecutor {
                       output_score,
                       sparse_attention_mask);
 
-        PlainTensor sum;
-        PlainTensor mask;
-        auto stride = 16;
-        auto block_size = 128;
-        auto threshold = 0.9f;
-
-        if (q.size(0) > 1) {
-            mask = xattn_estimate(q, k, block_size, stride, 1, threshold, true);
-        }
-
        if (rotated_block_indices) {
            // Rotate kv cache currently doesn't support quantized cache.
            // for u8 it only supports compilation but throws exception in the runtime
```
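(Net effect of this last hunk: the ad-hoc xattn_estimate call previously embedded in execute(), with its own stride/block_size/threshold locals, is removed; the estimation now runs once in get_sparse_blocks(), driven by the xt_stride, xt_block_size, and xt_threshold values set during input extraction.)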
