
Commit 8156e70

committed
exec_loop_mixed(exec_kernel_one_bh and exec_kernel_multiple) support
1 parent 8d7a97e commit 8156e70

File tree

2 files changed: +267, -48 lines


src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp

Lines changed: 209 additions & 38 deletions
@@ -2470,7 +2470,9 @@ struct MHAHelper {
                          const PlainTensor& alibi_slopes,
                          float* score_output,
                          size_t q_start_idx_score,
-                         const ScoreAggregationInfo* score_info_ptr) {
+                         const ScoreAggregationInfo* score_info_ptr,
+                         size_t batch_in_seq = 0,
+                         const std::vector<PlainTensor>& sparse_attention_mask = {}) {
         auto q_start = q_blk * _block_size;
         auto q_end = std::min(q_start + _block_size, q_len);
         auto q_cnt = q_end - q_start;
@@ -2487,6 +2489,12 @@ struct MHAHelper {
         // 1 1 1 0 ...
         // just computing the positions of 1 should be enough
         for (size_t k_blk = 0; k_blk < cur_kv_len_blocks; k_blk++) {
+            // sparse attention mask filtering
+            if (!sparse_attention_mask.empty() &&
+                !sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk, k_blk)[0]) {
+                // the mask is false: skip this block's GEMM
+                continue;
+            }
             auto* k_ptr = qk_scratch_b.ptr<DATA_TYPE>(k_blk, hk);
             _qk_gemm[q_cnt - 1]->executeGemm(q_cnt < _block_size,
                                              q_ptr,
@@ -2496,6 +2504,19 @@ struct MHAHelper {
                                              _qk_scratch_a ? _qk_scratch_a.ptr<DATA_TYPE>(ithr, 0) : nullptr);
         }
 
+        // sparse attention mask: when the (q_blk, k_blk) mask is false, set the score to -inf (whole q_blk score buffer)
+        if (!sparse_attention_mask.empty()) {
+            float* score_base = _weight.ptr<float>(ithr, h - hq_beg, 0);
+            for (size_t k_blk = 0; k_blk < cur_kv_len_blocks; k_blk++) {
+                if (!sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk, k_blk)[0]) {
+                    for (size_t m = 0; m < q_cnt; m++) {
+                        float* score_blk = score_base + m * cur_kv_len_blocks * _block_size + k_blk * _block_size;
+                        std::fill(score_blk, score_blk + _block_size, -std::numeric_limits<float>::infinity());
+                    }
+                }
+            }
+        }
+
         for (size_t m = q_start; m < q_end; m++) {
             // apply attention mask & sofmax
             auto ncausal = (cur_kv_len - q_cnt + (m - q_start) + 1);
@@ -2563,6 +2584,11 @@ struct MHAHelper {
 
         // for each weight block, loop through all value block
         for (size_t v_blk = 0; v_blk < cur_kv_len_blocks; v_blk++) {
+            // sparse attention mask filtering for value blocks
+            if (!sparse_attention_mask.empty() &&
+                !sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk, v_blk)[0]) {
+                continue;
+            }
             DATA_TYPE* v_ptr = nullptr;
             if (q_is_xf16 || !q_cache_is_same) {
                 v_ptr = wv_scratch_b.ptr<DATA_TYPE>(v_blk, hk);
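
Note: all three per-block loops in exec_kernel_multiple (the QK GEMM, the score-buffer fill, and the value-block accumulation above) gate on the same per-head boolean mask indexed by (h, q_blk, k_blk): a false entry skips the block's GEMM and instead fills its score slots with -inf so softmax assigns it zero weight. A minimal standalone sketch of that pattern, with an illustrative BlockMask type standing in for the [H, q_blocks, k_blocks] bool PlainTensor (not part of this commit):

#include <algorithm>
#include <cstddef>
#include <limits>
#include <vector>

// Illustrative stand-in for the [H, q_blocks, k_blocks] bool PlainTensor.
struct BlockMask {
    size_t H, q_blocks, k_blocks;
    std::vector<bool> data;  // row-major [H][q_blocks][k_blocks]
    bool at(size_t h, size_t q_blk, size_t k_blk) const {
        return data[(h * q_blocks + q_blk) * k_blocks + k_blk];
    }
};

// scores is [q_cnt, k_blocks * block_size] for one (h, q_blk). Blocks whose
// mask entry is false never ran their GEMM, so their score slots are
// overwritten with -inf before softmax, mirroring the diff above.
void apply_block_mask(float* scores, size_t q_cnt, size_t k_blocks, size_t block_size,
                      const BlockMask& mask, size_t h, size_t q_blk) {
    for (size_t k_blk = 0; k_blk < k_blocks; ++k_blk) {
        if (mask.at(h, q_blk, k_blk)) {
            continue;  // block participates; its scores were computed by the GEMM
        }
        for (size_t m = 0; m < q_cnt; ++m) {
            float* blk = scores + m * k_blocks * block_size + k_blk * block_size;
            std::fill(blk, blk + block_size, -std::numeric_limits<float>::infinity());
        }
    }
}
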
@@ -2762,14 +2788,21 @@ struct MHAHelper {
                            size_t q_len,
                            size_t cur_kv_len,
                            const PlainTensor& alibi_slopes,
-                           float* score_output) {
+                           float* score_output,
+                           const std::vector<PlainTensor>& sparse_attention_mask) {
 #    if defined(OPENVINO_ARCH_X86_64)
         if (one_of(_fastpath_valid_prec, ov::element::bf16, ov::element::f16)) {
             _gemv->tile_config();
             for (size_t pk = 0, i = 0; pk < cur_kv_len; pk += _block_size, i++) {
                 auto block_number = block_table[i];
                 for (size_t pq = 0; pq < q_len; pq++) {
+                    size_t q_blk = pq / _block_size;
                     for (size_t h = hq_beg; h < hq_end; h++) {
+                        size_t k_blk = pk / _block_size;
+                        // only process blocks where mask == true
+                        if (!sparse_attention_mask[0].ptr<bool>(h, q_blk, k_blk)[0]) {
+                            continue;
+                        }
                         (*_gemv)(
                             query.ptr<DATA_TYPE>(h, pq),
                             present_key.ptr<typename ov::element_type_traits<KEY_PREC>::value_type>(block_number, hk),
@@ -2783,7 +2816,13 @@ struct MHAHelper {
             for (size_t pk = 0, i = 0; pk < cur_kv_len; pk += _block_size, i++) {
                 auto block_number = block_table[i];
                 for (size_t pq = 0; pq < q_len; pq++) {
+                    size_t q_blk = pq / _block_size;
                     for (size_t h = hq_beg; h < hq_end; h++) {
+                        size_t k_blk = pk / _block_size;
+                        // only process blocks where mask == true
+                        if (!sparse_attention_mask[0].ptr<bool>(h, q_blk, k_blk)[0]) {
+                            continue;
+                        }
                         if constexpr (KEY_PREC == ov::element::u8 || KEY_PREC == ov::element::u4) {
                             dot_product_block_quantized<DATA_TYPE, KEY_PREC>(
                                 query.ptr<DATA_TYPE>(h, pq),
@@ -2811,6 +2850,7 @@ struct MHAHelper {
 #    endif
 
         for (size_t pq = 0; pq < q_len; pq++) {
+            size_t q_blk = pq / _block_size;
             for (size_t h = hq_beg; h < hq_end; h++) {
                 // apply attention mask & sofmax
                 float* alibi_lookup = nullptr;
@@ -2819,6 +2859,14 @@ struct MHAHelper {
                     alibi_slope = alibi_slopes.ptr<float>()[h];
                     alibi_lookup = _alibi_lookup.ptr<float>() + _alibi_lookup.m_dims[0] - cur_kv_len;
                 }
+                // mask: before softmax, set positions where mask == false to -inf
+                // TODO: parallel process outside q loop
+                for (size_t k = 0; k < cur_kv_len; ++k) {
+                    size_t k_blk = k / _block_size;
+                    if (!sparse_attention_mask[0].ptr<bool>(h, q_blk, k_blk)[0]) {
+                        _weight.ptr<float>(ithr, h - hq_beg, pq)[k] = -std::numeric_limits<float>::infinity();
+                    }
+                }
                 attn_softmax_kernel<float>(_weight.ptr<float>(ithr, h - hq_beg, pq),
                                            _weight.ptr<float>(ithr, h - hq_beg, pq),
                                            _d_scale,
@@ -2845,7 +2893,13 @@ struct MHAHelper {
         for (size_t pv = 0, i = 0; pv < cur_kv_len; pv += _block_size, i++) {
             auto block_number = block_table[i];
             for (size_t pq = 0; pq < q_len; pq++) {
+                size_t q_blk = pq / _block_size;
                 for (size_t h = hq_beg; h < hq_end; h++) {
+                    size_t k_blk = pv / _block_size;
+                    // only process blocks where mask == true
+                    if (!sparse_attention_mask[0].ptr<bool>(h, q_blk, k_blk)[0]) {
+                        continue;
+                    }
                     if constexpr (one_of(VALUE_PREC, ov::element::u8, ov::element::u4)) {
                         attn_acc_value_block_quantized<uint8_t, VALUE_PREC>(
                             _output.ptr<float>(ithr, pq, h),
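
Note: the exec_kernel_one_bh path maps a token index to its block with integer division (k_blk = pk / _block_size) and, unlike exec_kernel_multiple, indexes sparse_attention_mask[0] without an emptiness check, so this path assumes a non-empty mask. A minimal sketch of the pre-softmax row masking it performs (names illustrative, not OpenVINO API):

#include <cstddef>
#include <limits>
#include <vector>

// Mask one row of attention scores in place before softmax: every kv position
// whose block is false becomes -inf and thus gets zero attention weight.
// Assumes at least one block per row stays true, as the kernel does.
void mask_score_row(float* row, size_t cur_kv_len, size_t block_size,
                    const std::vector<bool>& k_block_mask) {
    for (size_t k = 0; k < cur_kv_len; ++k) {
        size_t k_blk = k / block_size;  // same mapping as k_blk = pk / _block_size
        if (!k_block_mask[k_blk]) {
            row[k] = -std::numeric_limits<float>::infinity();
        }
    }
}
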
@@ -3211,7 +3265,8 @@ struct MHA {
                        const PlainTensor& block_indices,
                        const PlainTensor& block_indices_begins,
                        const PlainTensor& alibi_slopes,
-                       const PlainTensor& score_aggregation_window) {
+                       const PlainTensor& score_aggregation_window,
+                       const std::vector<PlainTensor>& sparse_attention_mask) {
         auto Hk = v_cache.m_dims[1];
 
         constexpr bool q_is_xf16 = one_of(precision_of<DATA_TYPE>::value, ov::element::bf16, ov::element::f16);
@@ -3361,7 +3416,8 @@ struct MHA {
                                           1UL,
                                           cur_kv_len,
                                           alibi_slopes,
-                                          score_output);
+                                          score_output,
+                                          sparse_attention_mask);
             } else {
                 const auto batch_in_reorder = item.batch_in_reorder;
                 const auto q_blk = item.q_block_id;
@@ -3431,7 +3487,9 @@ struct MHA {
                         alibi_slopes,
                         score_output,
                         q_start_idx_score,
-                        score_info_ptr);
+                        score_info_ptr,
+                        batch_in_seq,
+                        sparse_attention_mask);
                 }
 #    else
                 _helper.exec_kernel_multiple(
@@ -3452,8 +3510,10 @@ struct MHA {
                     alibi_slopes,
                     score_output,
                     q_start_idx_score,
-                    score_info_ptr);
-#    endif
+                    score_info_ptr,
+                    batch_in_seq,
+                    sparse_attention_mask);
+#    endif
             }
         });
         if (output_score) {
@@ -3489,41 +3549,55 @@ struct MHA {
               const PlainTensor& block_indices,
               const PlainTensor& block_indices_begins,
               const PlainTensor& alibi_slopes,
-              const PlainTensor& score_aggregation_window) {
+              const PlainTensor& score_aggregation_window,
+              const std::vector<PlainTensor>& sparse_attention_mask) {
         _workitems.reset(query, past_lens, subsequence_begins, _helper._block_size);
         if (output_score) {
             _helper.init_score_buffers(past_lens, subsequence_begins, score_aggregation_window);
         }
 
         auto nthr = static_cast<size_t>(parallel_get_max_threads());
 
-        if (past_lens.m_dims[0] >= nthr || _workitems.get_reorder_max_batch_size() > 0) {
-            exec_loop_mixed(query,
-                            present_key,
-                            present_value,
-                            output_emb,
-                            output_score,
-                            max_context_len,
-                            past_lens,
-                            subsequence_begins,
-                            block_indices,
-                            block_indices_begins,
-                            alibi_slopes,
-                            score_aggregation_window);
-        } else {
-            _helper.exec_loop_bhl(query,
-                                  present_key,
-                                  present_value,
-                                  output_emb,
-                                  output_score,
-                                  max_context_len,
-                                  past_lens,
-                                  subsequence_begins,
-                                  block_indices,
-                                  block_indices_begins,
-                                  alibi_slopes,
-                                  score_aggregation_window);
-        }
+        // if (past_lens.m_dims[0] >= nthr || _workitems.get_reorder_max_batch_size() > 0) {
+        //     exec_loop_mixed(query,
+        //                     present_key,
+        //                     present_value,
+        //                     output_emb,
+        //                     output_score,
+        //                     max_context_len,
+        //                     past_lens,
+        //                     subsequence_begins,
+        //                     block_indices,
+        //                     block_indices_begins,
+        //                     alibi_slopes,
+        //                     score_aggregation_window);
+        // } else {
+        //     _helper.exec_loop_bhl(query,
+        //                           present_key,
+        //                           present_value,
+        //                           output_emb,
+        //                           output_score,
+        //                           max_context_len,
+        //                           past_lens,
+        //                           subsequence_begins,
+        //                           block_indices,
+        //                           block_indices_begins,
+        //                           alibi_slopes,
+        //                           score_aggregation_window);
+        // }
+        exec_loop_mixed(query,
+                        present_key,
+                        present_value,
+                        output_emb,
+                        output_score,
+                        max_context_len,
+                        past_lens,
+                        subsequence_begins,
+                        block_indices,
+                        block_indices_begins,
+                        alibi_slopes,
+                        score_aggregation_window,
+                        sparse_attention_mask);
     }
 };
 
@@ -3565,7 +3639,8 @@ struct AttentionExecutor : public PagedAttentionExecutor {
               PlainTensor& rotation_deltas,
               PlainTensor& rotation_trig_lut,
               PlainTensor& output_emb,
-              PlainTensor& output_score) {
+              PlainTensor& output_score,
+              std::vector<PlainTensor>& sparse_attention_mask) {
         q.reset(inputs[ID_Q]);  // [B_token, H * S]
         k.reset(inputs[ID_K]);
         v.reset(inputs[ID_V]);
@@ -3741,6 +3816,96 @@ struct AttentionExecutor : public PagedAttentionExecutor {
         // TODO: enable block_size to be multiple of 32
         OPENVINO_ASSERT(block_size == 32, "CPU: block size must be 32, current: ", block_size);
 
+        // --- create and initialize sparse_attention_mask ---
+        sparse_attention_mask.clear();
+        size_t k_blocks = div_up(max_context_len, block_size);
+        for (size_t b = 0; b < B_seq; ++b) {
+            auto q_len_for_batch = subsequence_begins.ptr<int32_t>()[b + 1] - subsequence_begins.ptr<int32_t>()[b];
+            size_t q_blocks = div_up(static_cast<size_t>(q_len_for_batch), block_size);
+            PlainTensor mask;
+            mask.resize<bool>({H, q_blocks, k_blocks});
+            // initialize everything to false by default
+            std::memset(mask.ptr<bool>(), 0, H * q_blocks * k_blocks * sizeof(bool));
+            // optional: activate all blocks (example)
+            for (size_t h = 0; h < H; ++h) {
+                for (size_t q_blk = 0; q_blk < q_blocks; ++q_blk) {
+                    for (size_t k_blk = 0; k_blk < k_blocks; ++k_blk) {
+                        // symmetric about (q_blocks/2, k_blocks/2): top-left and bottom-right true, the rest false
+                        // bool left_top = (q_blk < q_blocks / 2) && (k_blk < k_blocks / 2);
+                        // bool right_bottom = (q_blk >= (q_blocks + 1) / 2) && (k_blk >= (k_blocks + 1) / 2);
+                        // mask.ptr<bool>(h, q_blk, k_blk)[0] = left_top || right_bottom;
+
+                        // set every mask entry to true
+                        mask.ptr<bool>(h, q_blk, k_blk)[0] = true;
+
+                        // // set the middle block to false
+                        // if (q_blk == q_blocks / 2 && k_blk == k_blocks / 2) {
+                        //     mask.ptr<bool>(h, q_blk, k_blk)[0] = false;
+                        // } else {
+                        //     mask.ptr<bool>(h, q_blk, k_blk)[0] = true;
+                        // }
+                    }
+                }
+            }
+            sparse_attention_mask.push_back(std::move(mask));
+        }
+
+        // --- broadcast sparse_attention_mask to support different block sizes ---
+        // the incoming mask may be [H, q_blocks_orig, k_blocks_orig]; it needs broadcasting to [H, q_blocks, k_blocks]
+        auto broadcast_sparse_attention_mask =
+            [](std::vector<PlainTensor>& mask_vec, size_t src_block_size, size_t dst_block_size) {
+                if (src_block_size == dst_block_size)
+                    return;
+                if (src_block_size % dst_block_size != 0) {
+                    OPENVINO_THROW("not supported: sparse_attention_BlockSize=",
+                                   src_block_size,
+                                   " but block_size=",
+                                   dst_block_size);
+                }
+                size_t scale = src_block_size / dst_block_size;
+                for (auto& mask : mask_vec) {
+                    auto shape = mask.shape();
+                    size_t H = shape[0];
+                    size_t q_blocks_orig = shape[1];
+                    size_t k_blocks_orig = shape[2];
+                    size_t q_blocks = q_blocks_orig * scale;
+                    size_t k_blocks = k_blocks_orig * scale;
+                    PlainTensor new_mask;
+                    new_mask.resize<bool>({H, q_blocks, k_blocks});
+                    std::memset(new_mask.ptr<bool>(), 0, H * q_blocks * k_blocks * sizeof(bool));
+                    for (size_t h = 0; h < H; ++h) {
+                        for (size_t q_blk = 0; q_blk < q_blocks_orig; ++q_blk) {
+                            for (size_t k_blk = 0; k_blk < k_blocks_orig; ++k_blk) {
+                                bool val = mask.ptr<bool>(h, q_blk, k_blk)[0];
+                                for (size_t dq = 0; dq < scale; ++dq) {
+                                    for (size_t dk = 0; dk < scale; ++dk) {
+                                        new_mask.ptr<bool>(h, q_blk * scale + dq, k_blk * scale + dk)[0] = val;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                    mask = std::move(new_mask);
+                }
+            };
+        // block_size of the original sparse attention mask; to be supplied later via a Paged Attention node parameter
+        // const size_t sparse_attention_BlockSize = 128;
+        const size_t sparse_attention_BlockSize = 32;
+        // only block_size <= sparse_attention_BlockSize is supported, with sparse_attention_BlockSize an integer multiple of block_size
+        if (block_size != sparse_attention_BlockSize) {
+            if (block_size > sparse_attention_BlockSize) {
+                OPENVINO_THROW("not supported: block_size > sparse_attention_BlockSize");
+            }
+            if (sparse_attention_BlockSize % block_size != 0) {
+                OPENVINO_THROW("not supported: sparse_attention_BlockSize ",
+                               sparse_attention_BlockSize,
+                               " is not an integer multiple of block_size ",
+                               block_size);
+            }
+            broadcast_sparse_attention_mask(sparse_attention_mask, sparse_attention_BlockSize, block_size);
+        }
+
         _helper.init(H,
                      S,
                      SV,
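
Note: the broadcast lambda upsamples a coarse mask (e.g. sparse_attention_BlockSize = 128) to the runtime block size (32) by replicating each source cell into a scale x scale tile, where scale = src_block_size / dst_block_size. The same logic on flat vectors, as a self-contained sketch (std::vector<uint8_t> is an illustrative stand-in for PlainTensor):

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Upsample a row-major [H, Q, K] block mask by an integer factor.
std::vector<uint8_t> broadcast_block_mask(const std::vector<uint8_t>& mask,
                                          size_t H, size_t Q, size_t K,
                                          size_t src_block_size, size_t dst_block_size) {
    if (src_block_size % dst_block_size != 0) {
        throw std::runtime_error("src block size must be a multiple of dst block size");
    }
    const size_t scale = src_block_size / dst_block_size;
    const size_t Qn = Q * scale;
    const size_t Kn = K * scale;
    std::vector<uint8_t> out(H * Qn * Kn, 0);
    for (size_t h = 0; h < H; ++h) {
        for (size_t q = 0; q < Q; ++q) {
            for (size_t k = 0; k < K; ++k) {
                const uint8_t val = mask[(h * Q + q) * K + k];  // one coarse cell
                for (size_t dq = 0; dq < scale; ++dq) {
                    for (size_t dk = 0; dk < scale; ++dk) {
                        out[(h * Qn + q * scale + dq) * Kn + (k * scale + dk)] = val;
                    }
                }
            }
        }
    }
    return out;
}

With src = 128 and dst = 32 each coarse cell expands into a 4 x 4 tile, matching the commented-out sparse_attention_BlockSize = 128 case above.
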
@@ -3823,6 +3988,10 @@ struct AttentionExecutor : public PagedAttentionExecutor {
         PlainTensor output_emb;
         PlainTensor output_score;
 
+        std::vector<PlainTensor>
+            sparse_attention_mask;  // one PlainTensor per batch (sequence); layout:
+                                    // [H, q_blocks, k_blocks], bool
+
         init(inputs,
              outputs,
              q,
@@ -3843,7 +4012,8 @@ struct AttentionExecutor : public PagedAttentionExecutor {
              rotation_deltas,
              rotation_trig_lut,
              output_emb,
-             output_score);
+             output_score,
+             sparse_attention_mask);
 
         if (rotated_block_indices) {
             // Rotate kv cache currently doesn't support quantized cache.
@@ -3869,7 +4039,8 @@ struct AttentionExecutor : public PagedAttentionExecutor {
                   block_indices,
                   block_indices_begins,
                   alibi_slopes,
-                  score_aggregation_window);
+                  score_aggregation_window,
+                  sparse_attention_mask);
     }
 };
 #endif
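
Note: the mask built in init is currently all-true, so the sparse path should reproduce dense attention exactly: no block GEMM is skipped and no score is set to -inf. A small self-check of that invariant (pure illustration, not part of the commit or its tests):

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Plain softmax over one score row.
std::vector<float> softmax(std::vector<float> row) {
    float m = *std::max_element(row.begin(), row.end());
    float sum = 0.0f;
    for (float& v : row) {
        v = std::exp(v - m);
        sum += v;
    }
    for (float& v : row) {
        v /= sum;
    }
    return row;
}

int main() {
    const size_t block_size = 4;
    std::vector<float> scores = {0.1f, 0.9f, -0.3f, 0.2f, 0.5f, 0.0f, 1.2f, -0.7f};
    std::vector<bool> block_mask(scores.size() / block_size, true);  // all-true default

    std::vector<float> masked = scores;
    for (size_t k = 0; k < masked.size(); ++k) {
        if (!block_mask[k / block_size]) {
            masked[k] = -std::numeric_limits<float>::infinity();
        }
    }

    auto dense = softmax(scores);
    auto sparse = softmax(masked);
    for (size_t k = 0; k < dense.size(); ++k) {
        assert(std::fabs(dense[k] - sparse[k]) < 1e-6f);  // identical weights
    }
    return 0;
}
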
