
Commit d96bc30

optimize sparse mask broadcast: use block index map

1 parent 9a83787

1 file changed: +28 -50 lines

src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp
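In place of the eagerly materialized broadcast (see the deleted lambda further down), the kernel now keeps the sparse mask at its original granularity and divides runtime block indices by the block-size ratio at each lookup. Below is a minimal standalone sketch of that mapping, with illustrative names (`mask_block_size` and `runtime_block_size` stand in for the `_sparse_mask_block_size` and `_block_size` members; the real code uses a capturing lambda):

#include <cstddef>
#include <iostream>
#include <utility>

// Sketch of the on-the-fly index mapping (not the exact OpenVINO code):
// a mask generated with larger blocks is indexed by dividing the runtime
// block indices by the ratio of the two block sizes.
std::pair<std::size_t, std::size_t> map_to_mask_idx(std::size_t q_blk,
                                                    std::size_t k_blk,
                                                    std::size_t mask_block_size,
                                                    std::size_t runtime_block_size) {
    // 0 acts as "unspecified": treat the mask as already at runtime granularity.
    if (mask_block_size == 0 || mask_block_size == runtime_block_size) {
        return {q_blk, k_blk};
    }
    // mask_block_size is assumed to be a multiple of runtime_block_size.
    const std::size_t scale = mask_block_size / runtime_block_size;
    return {q_blk / scale, k_blk / scale};
}

int main() {
    // Mask built with 128-wide blocks, kernel running with 32-wide blocks:
    // four runtime blocks share one mask entry along each axis.
    auto [q_m, k_m] = map_to_mask_idx(/*q_blk=*/7, /*k_blk=*/5, 128, 32);
    std::cout << q_m << ' ' << k_m << '\n';  // prints "1 1"
}

The diff follows.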
@@ -584,6 +584,8 @@ struct MHAHelper {
     std::vector<ScoreAggregationInfo> _score_infos;
 
     PlainTensor _block_rotation_coefficient_scratch;
+    // Block size used when generating sparse_attention_mask (0 means unspecified/equal to _block_size)
+    size_t _sparse_mask_block_size = 0;
 
     MHAHelper() {
         _weight.resize<float>({size_t{1}, size_t{1}, size_t{1}, size_t{1}});
@@ -853,12 +855,25 @@ struct MHAHelper {
         // 1 1 0 0 ...
         // 1 1 1 0 ...
         // just computing the positions of 1 should be enough
+        // map runtime (block_size) indices to mask (xt_block_size) indices
+        auto map_to_mask_idx = [&](size_t q_blk_rt, size_t k_blk_rt) {
+            if (_sparse_mask_block_size == 0 || _sparse_mask_block_size == _block_size) {
+                return std::pair<size_t, size_t>{q_blk_rt, k_blk_rt};
+            }
+            // Only support mask block >= runtime block and divisible (checked in init)
+            size_t scale = _sparse_mask_block_size / _block_size;  // >=1
+            size_t q_mask = q_blk_rt / scale;
+            size_t k_mask = k_blk_rt / scale;
+            return std::pair<size_t, size_t>{q_mask, k_mask};
+        };
         for (size_t k_blk = 0; k_blk < cur_kv_len_blocks; k_blk++) {
             // sparse attention mask filtering
-            if (!sparse_attention_mask.empty() &&
-                !sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk, k_blk)[0]) {
-                // Skip GEMM for this block if mask is false
-                continue;
+            if (!sparse_attention_mask.empty()) {
+                auto [q_m, k_m] = map_to_mask_idx(q_blk, k_blk);
+                if (!sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_m, k_m)[0]) {
+                    // Skip GEMM for this block if mask is false
+                    continue;
+                }
             }
             if (_params.is_sage_attn) {
 #    if defined(OPENVINO_ARCH_X86_64)
@@ -900,7 +915,8 @@ struct MHAHelper {
             std::fill(softmax_mask_storage.begin(), softmax_mask_storage.end(), neg_inf_val);
             for (size_t k = 0; k < cur_kv_len; ++k) {
                 size_t k_blk = k / _block_size;
-                if (sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk, k_blk)[0]) {
+                auto [q_m, k_m] = map_to_mask_idx(q_blk, k_blk);
+                if (sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_m, k_m)[0]) {
                     softmax_mask_storage[k] = static_cast<DATA_TYPE>(0);
                 }
             }
@@ -980,9 +996,11 @@ struct MHAHelper {
         // for each weight block, loop through all value block
         for (size_t v_blk = 0; v_blk < cur_kv_len_blocks; v_blk++) {
             // sparse attention mask filtering for value blocks
-            if (!sparse_attention_mask.empty() &&
-                !sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk, v_blk)[0]) {
-                continue;
+            if (!sparse_attention_mask.empty()) {
+                auto [q_m, v_m] = map_to_mask_idx(q_blk, v_blk);
+                if (!sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_m, v_m)[0]) {
+                    continue;
+                }
             }
             DATA_TYPE* v_ptr = nullptr;
             if (q_is_xf16 || !q_cache_is_same) {
@@ -2182,47 +2200,6 @@ struct AttentionExecutor : public PagedAttentionExecutor {
                            xt_block_size,
                            xt_threshold);
 
-            // --- Broadcast sparse_attention_mask to support different block sizes ---
-            // The input mask may be [h, q_blocks_orig, k_blocks_orig], and needs to be broadcast to [h, q_blocks,
-            // k_blocks]
-            auto broadcast_sparse_attention_mask =
-                [](std::vector<PlainTensor>& mask_vec, size_t src_block_size, size_t dst_block_size) {
-                    if (src_block_size == dst_block_size)
-                        return;
-                    if (src_block_size % dst_block_size != 0) {
-                        OPENVINO_THROW("not supported: sparse_attention_BlockSize=",
-                                       src_block_size,
-                                       " is not an integer multiple of block_size=",
-                                       dst_block_size);
-                    }
-                    size_t scale = src_block_size / dst_block_size;
-                    for (auto& mask : mask_vec) {
-                        auto shape = mask.shape();
-                        size_t H = shape[0];
-                        size_t q_blocks_orig = shape[1];
-                        size_t k_blocks_orig = shape[2];
-                        size_t q_blocks = q_blocks_orig * scale;
-                        size_t k_blocks = k_blocks_orig * scale;
-                        PlainTensor new_mask;
-                        new_mask.resize<bool>({H, q_blocks, k_blocks});
-                        std::memset(new_mask.ptr<bool>(), 0, H * q_blocks * k_blocks * sizeof(bool));
-                        for (size_t h = 0; h < H; ++h) {
-                            for (size_t q_blk = 0; q_blk < q_blocks_orig; ++q_blk) {
-                                for (size_t k_blk = 0; k_blk < k_blocks_orig; ++k_blk) {
-                                    bool val = mask.ptr<bool>(h, q_blk, k_blk)[0];
-                                    for (size_t dq = 0; dq < scale; ++dq) {
-                                        for (size_t dk = 0; dk < scale; ++dk) {
-                                            new_mask.ptr<bool>(h, q_blk * scale + dq, k_blk * scale + dk)[0] = val;
-                                        }
-                                    }
-                                }
-                            }
-                        }
-                        mask = std::move(new_mask);
-                    }
-                };
-            // The original block_size of the sparse attention mask; can be specified later via the Page Attention Node
-            // parameter const size_t sparse_attention_BlockSize = 128;
             // Only support block_size <= sparse_attention_BlockSize and sparse_attention_BlockSize must be an integer
             // multiple
             if (block_size != xt_block_size) {
@@ -2235,8 +2212,9 @@ struct AttentionExecutor : public PagedAttentionExecutor {
                                    " is not an integer multiple of block_size ",
                                    block_size);
                 }
-                broadcast_sparse_attention_mask(sparse_attention_mask, xt_block_size, block_size);
             }
+            // keep original mask granularity; remember its block size for on-the-fly mapping
+            _helper._sparse_mask_block_size = xt_block_size;
         }
 
         _helper.init(H,
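The new mapping should select exactly the entries the deleted broadcast would have produced: tiling a mask entry over a scale-by-scale block and then indexing at runtime granularity is the same as dividing the runtime indices by scale. A small self-contained check of that equivalence (a plain std::vector stands in for PlainTensor; the sizes are arbitrary):

#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
    const std::size_t scale = 4;               // e.g. xt_block_size 128 / block_size 32
    const std::size_t q_orig = 3, k_orig = 3;  // mask shape at its original granularity
    const std::vector<bool> mask = {1, 0, 0,   // toy causal-ish mask,
                                    1, 1, 0,   // row-major [q_orig x k_orig]
                                    0, 1, 1};

    // Old approach: materialize the mask at runtime granularity by tiling each
    // entry over a scale x scale block (mirrors the deleted lambda's loops).
    const std::size_t q_rt = q_orig * scale, k_rt = k_orig * scale;
    std::vector<bool> broadcast(q_rt * k_rt);
    for (std::size_t q = 0; q < q_orig; ++q)
        for (std::size_t k = 0; k < k_orig; ++k)
            for (std::size_t dq = 0; dq < scale; ++dq)
                for (std::size_t dk = 0; dk < scale; ++dk)
                    broadcast[(q * scale + dq) * k_rt + (k * scale + dk)] = mask[q * k_orig + k];

    // New approach: no intermediate tensor, divide the runtime indices on lookup.
    for (std::size_t q = 0; q < q_rt; ++q)
        for (std::size_t k = 0; k < k_rt; ++k)
            assert(broadcast[q * k_rt + k] == mask[(q / scale) * k_orig + (k / scale)]);

    std::cout << "mapped lookup matches broadcast for all " << q_rt * k_rt << " block pairs\n";
}

The practical payoff, as the commit title suggests, is that the per-sequence mask allocation and the scale-squared copy loops disappear; each lookup costs two integer divisions instead.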
