Commit 9a83787

fix parallel output issue

1 parent: 9eb509b

File tree: 1 file changed (+79, -89 lines)


src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp

Lines changed: 79 additions & 89 deletions
@@ -889,17 +889,22 @@ struct MHAHelper {
                 }
             }

-            // sparse attention mask: set score to -inf for (q_blk, k_blk) where
+            // Instead of writing -inf directly into scores, build a softmax mask (0/-inf) and pass it to the kernel
+            DATA_TYPE* softmax_mask = nullptr;
+            std::vector<DATA_TYPE> softmax_mask_storage;
             if (!sparse_attention_mask.empty()) {
-                float* score_base = _weight.ptr<float>(ithr, h - hq_beg, 0);
-                for (size_t k_blk = 0; k_blk < cur_kv_len_blocks; k_blk++) {
-                    if (!sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk, k_blk)[0]) {
-                        for (size_t m = 0; m < q_cnt; m++) {
-                            float* score_blk = score_base + m * cur_kv_len_blocks * _block_size + k_blk * _block_size;
-                            std::fill(score_blk, score_blk + _block_size, -std::numeric_limits<float>::infinity());
-                        }
+                const size_t padded_len = rnd_up(cur_kv_len, _block_size);
+                softmax_mask_storage.resize(padded_len);
+                // Initialize to -inf by default; then set positions for allowed blocks to 0
+                const DATA_TYPE neg_inf_val = static_cast<DATA_TYPE>(-std::numeric_limits<float>::infinity());
+                std::fill(softmax_mask_storage.begin(), softmax_mask_storage.end(), neg_inf_val);
+                for (size_t k = 0; k < cur_kv_len; ++k) {
+                    size_t k_blk = k / _block_size;
+                    if (sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk, k_blk)[0]) {
+                        softmax_mask_storage[k] = static_cast<DATA_TYPE>(0);
                     }
                 }
+                softmax_mask = softmax_mask_storage.data();
             }

             for (size_t m = q_start; m < q_end; m++) {
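Conceptually, the rewrite above swaps "poke -inf into the score buffer" for an additive mask that the softmax kernel consumes: kept KV positions contribute 0, masked positions contribute -inf, and the storage is padded to rnd_up(cur_kv_len, _block_size) so the vectorized kernel always reads whole blocks (padding entries simply stay at -inf). A minimal standalone sketch of the additive-mask behavior, using plain float and hypothetical names rather than the kernel's actual interface:

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Sketch: softmax over (score + mask), where mask[i] is 0 for kept positions
// and -inf for masked ones. exp(-inf) == 0, so masked positions vanish from
// both the numerator and the normalizer. Assumes at least one kept position
// per row (causal attention always keeps the diagonal block).
std::vector<float> masked_softmax(const std::vector<float>& scores,
                                  const std::vector<float>& mask) {
    float max_v = -std::numeric_limits<float>::infinity();
    for (std::size_t i = 0; i < scores.size(); ++i)
        max_v = std::max(max_v, scores[i] + mask[i]);
    std::vector<float> out(scores.size());
    float sum = 0.0f;
    for (std::size_t i = 0; i < scores.size(); ++i) {
        out[i] = std::exp(scores[i] + mask[i] - max_v);  // 0 for masked entries
        sum += out[i];
    }
    for (float& v : out)
        v /= sum;
    return out;
}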
@@ -923,7 +928,7 @@ struct MHAHelper {
                     reinterpret_cast<DATA_TYPE*>(score) + start_idx,
                     revised_d_scale,
                     alibi_lookup,
-                    nullptr,
+                    reinterpret_cast<void*>(softmax_mask + start_idx),
                     nullptr,
                     false,
                     new_causal,
@@ -944,7 +949,7 @@ struct MHAHelper {
                     reinterpret_cast<DATA_TYPE*>(score),
                     revised_d_scale,
                     alibi_lookup,
-                    nullptr,
+                    reinterpret_cast<void*>(softmax_mask),
                     nullptr,
                     false,
                     ncausal,
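Both call sites above swap the previously hard-coded nullptr for the mask pointer. One subtlety: when sparse_attention_mask is empty, softmax_mask stays nullptr, so the first call site's softmax_mask + start_idx performs arithmetic on a null pointer, which is undefined behavior in standard C++ even though common compilers tolerate it. A hypothetical defensive form of that argument (not part of the commit):

    // Hypothetical guard: only offset the pointer when mask storage actually exists.
    void* mask_arg = softmax_mask ? reinterpret_cast<void*>(softmax_mask + start_idx) : nullptr;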
@@ -1856,7 +1861,7 @@ struct MHA {
                    score_info_ptr,
                    batch_in_seq,
                    sparse_attention_mask);
-# endif
+#        endif
            }
        });
        if (output_score) {
@@ -2158,95 +2163,80 @@ struct AttentionExecutor : public PagedAttentionExecutor {
         // TODO: enable block_size to be multiple of 32
         OPENVINO_ASSERT(block_size == 32, "CPU: block size must be 32, current: ", block_size);

-        // TODO: use the inputs values
-        // PlainTensor sum;
-        // PlainTensor mask;
         size_t xt_stride = 16;
         // The original block_size of the sparse attention mask;
         size_t xt_block_size = 128;
         // auto xt_block_size = 32;
-        float xt_threshold = 0.9f;
-
-        // --- Create and initialize sparse_attention_mask ---
-        sparse_attention_mask.clear();
-        // TODO: maybe use real context_len to save memory usage
-        size_t k_blocks = div_up(max_context_len, xt_block_size);
-        for (size_t b = 0; b < B_seq; ++b) {
-            auto q_len_for_batch = subsequence_begins.ptr<int32_t>()[b + 1] - subsequence_begins.ptr<int32_t>()[b];
-            size_t q_blocks = div_up(static_cast<size_t>(q_len_for_batch), xt_block_size);
-            PlainTensor mask;
-            mask.resize<bool>({H, q_blocks, k_blocks});
-            // Default initialize all to false
-            std::memset(mask.ptr<bool>(), 0, H * q_blocks * k_blocks * sizeof(bool));
-            sparse_attention_mask.push_back(std::move(mask));
-        }
-
-        // TODO: Avoid temporary vector and assignment here. Consider changing get_sparse_blocks
-        // to take `std::vector<PlainTensor>& sparse_attention_mask` and fill it in-place
-        // to reduce PlainTensor copies (memcpy of metadata) and allocations.
-        sparse_attention_mask = get_sparse_blocks(q,
-                                                  k,
-                                                  past_lens,
-                                                  subsequence_begins,
-                                                  block_indices,
-                                                  block_indices_begins,
-                                                  xt_stride,
-                                                  xt_block_size,
-                                                  xt_threshold);
-
-        // --- Broadcast sparse_attention_mask to support different block sizes ---
-        // The input mask may be [h, q_blocks_orig, k_blocks_orig], and needs to be broadcast to [h, q_blocks, k_blocks]
-        auto broadcast_sparse_attention_mask =
-            [](std::vector<PlainTensor>& mask_vec, size_t src_block_size, size_t dst_block_size) {
-                if (src_block_size == dst_block_size)
-                    return;
-                if (src_block_size % dst_block_size != 0) {
-                    OPENVINO_THROW("not supported: sparse_attention_BlockSize=",
-                                   src_block_size,
-                                   " is not an integer multiple of block_size=",
-                                   dst_block_size);
-                }
-                size_t scale = src_block_size / dst_block_size;
-                for (auto& mask : mask_vec) {
-                    auto shape = mask.shape();
-                    size_t H = shape[0];
-                    size_t q_blocks_orig = shape[1];
-                    size_t k_blocks_orig = shape[2];
-                    size_t q_blocks = q_blocks_orig * scale;
-                    size_t k_blocks = k_blocks_orig * scale;
-                    PlainTensor new_mask;
-                    new_mask.resize<bool>({H, q_blocks, k_blocks});
-                    std::memset(new_mask.ptr<bool>(), 0, H * q_blocks * k_blocks * sizeof(bool));
-                    for (size_t h = 0; h < H; ++h) {
-                        for (size_t q_blk = 0; q_blk < q_blocks_orig; ++q_blk) {
-                            for (size_t k_blk = 0; k_blk < k_blocks_orig; ++k_blk) {
-                                bool val = mask.ptr<bool>(h, q_blk, k_blk)[0];
-                                for (size_t dq = 0; dq < scale; ++dq) {
-                                    for (size_t dk = 0; dk < scale; ++dk) {
-                                        new_mask.ptr<bool>(h, q_blk * scale + dq, k_blk * scale + dk)[0] = val;
+        float xt_threshold = 0.6f;
+        // float xt_threshold = 1.0f;
+
+        // If to support second token sparse attention, need generate sparse mask after concat_pastkv
+        if (q.size(0) > 1) {
+            sparse_attention_mask = get_sparse_blocks(q,
+                                                      k,
+                                                      past_lens,
+                                                      subsequence_begins,
+                                                      block_indices,
+                                                      block_indices_begins,
+                                                      xt_stride,
+                                                      xt_block_size,
+                                                      xt_threshold);
+
+            // --- Broadcast sparse_attention_mask to support different block sizes ---
+            // The input mask may be [h, q_blocks_orig, k_blocks_orig], and needs to be broadcast to [h, q_blocks,
+            // k_blocks]
+            auto broadcast_sparse_attention_mask =
+                [](std::vector<PlainTensor>& mask_vec, size_t src_block_size, size_t dst_block_size) {
+                    if (src_block_size == dst_block_size)
+                        return;
+                    if (src_block_size % dst_block_size != 0) {
+                        OPENVINO_THROW("not supported: sparse_attention_BlockSize=",
+                                       src_block_size,
+                                       " is not an integer multiple of block_size=",
+                                       dst_block_size);
+                    }
+                    size_t scale = src_block_size / dst_block_size;
+                    for (auto& mask : mask_vec) {
+                        auto shape = mask.shape();
+                        size_t H = shape[0];
+                        size_t q_blocks_orig = shape[1];
+                        size_t k_blocks_orig = shape[2];
+                        size_t q_blocks = q_blocks_orig * scale;
+                        size_t k_blocks = k_blocks_orig * scale;
+                        PlainTensor new_mask;
+                        new_mask.resize<bool>({H, q_blocks, k_blocks});
+                        std::memset(new_mask.ptr<bool>(), 0, H * q_blocks * k_blocks * sizeof(bool));
+                        for (size_t h = 0; h < H; ++h) {
+                            for (size_t q_blk = 0; q_blk < q_blocks_orig; ++q_blk) {
+                                for (size_t k_blk = 0; k_blk < k_blocks_orig; ++k_blk) {
+                                    bool val = mask.ptr<bool>(h, q_blk, k_blk)[0];
+                                    for (size_t dq = 0; dq < scale; ++dq) {
+                                        for (size_t dk = 0; dk < scale; ++dk) {
+                                            new_mask.ptr<bool>(h, q_blk * scale + dq, k_blk * scale + dk)[0] = val;
+                                        }
                                     }
                                 }
                             }
                         }
+                        mask = std::move(new_mask);
                     }
-                    mask = std::move(new_mask);
+                };
+            // The original block_size of the sparse attention mask; can be specified later via the Page Attention Node
+            // parameter const size_t sparse_attention_BlockSize = 128;
+            // Only support block_size <= sparse_attention_BlockSize and sparse_attention_BlockSize must be an integer
+            // multiple
+            if (block_size != xt_block_size) {
+                if (block_size > xt_block_size) {
+                    OPENVINO_THROW("not supported: block_size > xt_block_size");
                 }
-            };
-        // The original block_size of the sparse attention mask; can be specified later via the Page Attention Node
-        // parameter const size_t sparse_attention_BlockSize = 128;
-        // Only support block_size <= sparse_attention_BlockSize and sparse_attention_BlockSize must be an integer
-        // multiple
-        if (block_size != xt_block_size) {
-            if (block_size > xt_block_size) {
-                OPENVINO_THROW("not supported: block_size > xt_block_size");
-            }
-            if (xt_block_size % block_size != 0) {
-                OPENVINO_THROW("not supported: xt_block_size ",
-                               xt_block_size,
-                               " is not an integer multiple of block_size ",
-                               block_size);
+                if (xt_block_size % block_size != 0) {
+                    OPENVINO_THROW("not supported: xt_block_size ",
+                                   xt_block_size,
+                                   " is not an integer multiple of block_size ",
+                                   block_size);
+                }
+                broadcast_sparse_attention_mask(sparse_attention_mask, xt_block_size, block_size);
             }
-            broadcast_sparse_attention_mask(sparse_attention_mask, xt_block_size, block_size);
         }

         _helper.init(H,
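The broadcast lambda in this hunk upsamples the coarse boolean mask of shape [H, q_blocks, k_blocks], produced at xt_block_size = 128, down to the kernel's block_size = 32 by replicating each coarse cell into a scale x scale tile of fine blocks, where scale = src_block_size / dst_block_size (here 128 / 32 = 4). A self-contained sketch of the same expansion on a flat std::vector<std::uint8_t> standing in for PlainTensor (function and variable names are illustrative):

#include <cstddef>
#include <cstdint>
#include <vector>

// Expand a [H, q_blocks, k_blocks] block mask to [H, q_blocks*scale, k_blocks*scale]:
// every coarse (q, k) cell is copied into a scale x scale tile of fine blocks.
std::vector<std::uint8_t> broadcast_mask(const std::vector<std::uint8_t>& src,
                                         std::size_t H,
                                         std::size_t q_blocks,
                                         std::size_t k_blocks,
                                         std::size_t scale) {
    const std::size_t dst_k = k_blocks * scale;
    std::vector<std::uint8_t> dst(H * q_blocks * scale * dst_k, 0);
    for (std::size_t h = 0; h < H; ++h)
        for (std::size_t q = 0; q < q_blocks; ++q)
            for (std::size_t k = 0; k < k_blocks; ++k) {
                const std::uint8_t val = src[(h * q_blocks + q) * k_blocks + k];
                for (std::size_t dq = 0; dq < scale; ++dq)
                    for (std::size_t dk = 0; dk < scale; ++dk)
                        dst[((h * q_blocks + q) * scale + dq) * dst_k + k * scale + dk] = val;
            }
    return dst;
}

With H = 1, q_blocks = k_blocks = 1 and scale = 4, one kept coarse block becomes a 4 x 4 patch of kept fine blocks; sparsity granularity therefore never becomes finer than xt_block_size, which is why the code only accepts block_size <= xt_block_size with an exact divisibility requirement.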
