Commit dcbfb2c

exec_loop_bhl: add sparse_attention_mask support; broadcast_sparse_attention_mask: support a different block_size
1 parent 8156e70 commit dcbfb2c


src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp

Lines changed: 73 additions & 41 deletions
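The hunks below gate per-block work on a boolean mask indexed by (head, query block, key block) and force masked positions to -inf before softmax in the bhl path. As a minimal, standalone sketch of that indexing convention (not the OpenVINO kernel itself; the `BlockMask`/`mask_row` names and the `[H][q_blocks][k_blocks]` layout are assumptions for illustration only):

```cpp
#include <cstddef>
#include <cstdio>
#include <limits>
#include <vector>

struct BlockMask {
    size_t num_heads, q_blocks, k_blocks;
    std::vector<char> keep_flags;  // flattened [H][q_blocks][k_blocks], 1 = compute this block

    bool keep(size_t h, size_t q_blk, size_t k_blk) const {
        return keep_flags[(h * q_blocks + q_blk) * k_blocks + k_blk] != 0;
    }
};

// Apply the mask to one row of attention logits (one head, one query token): every key
// position whose block is masked out becomes -inf, so softmax assigns it zero probability.
void mask_row(float* logits, size_t kv_len, size_t block_size,
              const BlockMask& mask, size_t h, size_t q_pos) {
    const size_t q_blk = q_pos / block_size;
    for (size_t k = 0; k < kv_len; ++k) {
        const size_t k_blk = k / block_size;
        if (!mask.keep(h, q_blk, k_blk)) {
            logits[k] = -std::numeric_limits<float>::infinity();
        }
    }
}

int main() {
    const size_t block_size = 4, kv_len = 8;
    BlockMask mask{1, 1, 2, {1, 0}};  // head 0, query block 0: keep k-block 0, drop k-block 1
    std::vector<float> logits(kv_len, 0.5f);
    mask_row(logits.data(), kv_len, block_size, mask, /*h=*/0, /*q_pos=*/0);
    for (float v : logits) std::printf("%g ", v);  // prints 0.5 four times, then -inf four times
    std::printf("\n");
}
```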
@@ -2950,7 +2950,8 @@ struct MHAHelper {
                        const PlainTensor& block_indices,
                        const PlainTensor& block_indices_begins,
                        const PlainTensor& alibi_slopes,
-                       const PlainTensor& score_aggregation_window) {
+                       const PlainTensor& score_aggregation_window,
+                       const std::vector<PlainTensor>& sparse_attention_mask = {}) {
         auto B = past_lens.size(0);
         auto q_len = query.size(2);
         auto kv_len_in_blocks = div_up(max_context_len, _block_size);
@@ -2995,13 +2996,20 @@
 
             // kv_len must be valid
             auto pk = pk_in_blocks * _block_size;
+            size_t k_blk = pk / _block_size;
             if (pk < context_len) {
                 auto block_number = block_indices.ptr<int32_t>()[block_indices_begins.ptr<int32_t>()[b] + pk_in_blocks];
 #    if defined(OPENVINO_ARCH_X86_64)
                 if (one_of(_fastpath_valid_prec, ov::element::bf16, ov::element::f16)) {
                     _gemv->tile_config();
                     for (size_t pq = 0; pq < q_len; pq++) {
+                        size_t q_blk = pq / _block_size;
                         for (size_t h = hq_beg; h < hq_end; h++) {
+                            // sparse attention mask check
+                            if (!sparse_attention_mask.empty() &&
+                                !sparse_attention_mask[b].ptr<bool>(h, q_blk, k_blk)[0]) {
+                                continue;
+                            }
                             (*_gemv)(
                                 query.ptr<DATA_TYPE>(b, h, pq),
                                 key_cache.ptr<typename ov::element_type_traits<KEY_PREC>::value_type>(block_number, hk),
@@ -3012,7 +3020,13 @@
                 } else {
 #    endif
                     for (size_t pq = 0; pq < q_len; pq++) {
+                        size_t q_blk = pq / _block_size;
                         for (size_t h = hq_beg; h < hq_end; h++) {
+                            // sparse attention mask check
+                            if (!sparse_attention_mask.empty() &&
+                                !sparse_attention_mask[b].ptr<bool>(h, q_blk, k_blk)[0]) {
+                                continue;
+                            }
                             if constexpr (one_of(KEY_PREC, ov::element::u8, ov::element::u4)) {
                                 dot_product_block_quantized<DATA_TYPE, KEY_PREC>(
                                     query.ptr<DATA_TYPE>(b, h, pq),
@@ -3050,6 +3064,16 @@
                 alibi_slope = alibi_slopes.ptr<float>()[h];
                 alibi_lookup = _alibi_lookup.ptr<float>() + _alibi_lookup.m_dims[0] - cur_kv_len;
             }
+            // sparse attention mask: set positions where mask == false to -inf
+            if (!sparse_attention_mask.empty()) {
+                size_t q_blk = pq / _block_size;
+                for (size_t k = 0; k < cur_kv_len; ++k) {
+                    size_t k_blk = k / _block_size;
+                    if (!sparse_attention_mask[b].ptr<bool>(h, q_blk, k_blk)[0]) {
+                        _weight_bhl.ptr<float>(b, h, pq)[k] = -std::numeric_limits<float>::infinity();
+                    }
+                }
+            }
             attn_softmax_kernel<float>(_weight_bhl.ptr<float>(b, h, pq),
                                        _weight_bhl.ptr<float>(b, h, pq),
                                        _d_scale,
@@ -3105,8 +3129,14 @@
             // kv_len must be valid
             if (pv < context_len) {
                 auto block_number = block_indices.ptr<int32_t>()[block_indices_begins.ptr<int32_t>()[b] + pv_in_blocks];
+                size_t k_blk = pv / _block_size;
                 for (size_t pq = 0; pq < q_len; pq++) {
+                    size_t q_blk = pq / _block_size;
                     for (size_t h = hq_beg; h < hq_end; h++) {
+                        // sparse attention mask check
+                        if (!sparse_attention_mask.empty() && !sparse_attention_mask[b].ptr<bool>(h, q_blk, k_blk)[0]) {
+                            continue;
+                        }
                         if constexpr (one_of(VALUE_PREC, ov::element::u8, ov::element::u4)) {
                             attn_acc_value_block_quantized<uint8_t, VALUE_PREC>(
                                 _output_bhl.ptr<float>(b, pv_in_blocks, h, pq),
@@ -3558,46 +3588,48 @@ struct MHA {
 
         auto nthr = static_cast<size_t>(parallel_get_max_threads());
 
-        // if (past_lens.m_dims[0] >= nthr || _workitems.get_reorder_max_batch_size() > 0) {
-        //     exec_loop_mixed(query,
-        //                     present_key,
-        //                     present_value,
-        //                     output_emb,
-        //                     output_score,
-        //                     max_context_len,
-        //                     past_lens,
-        //                     subsequence_begins,
-        //                     block_indices,
-        //                     block_indices_begins,
-        //                     alibi_slopes,
-        //                     score_aggregation_window);
-        // } else {
-        //     _helper.exec_loop_bhl(query,
-        //                           present_key,
-        //                           present_value,
-        //                           output_emb,
-        //                           output_score,
-        //                           max_context_len,
-        //                           past_lens,
-        //                           subsequence_begins,
-        //                           block_indices,
-        //                           block_indices_begins,
-        //                           alibi_slopes,
-        //                           score_aggregation_window);
-        // }
-        exec_loop_mixed(query,
-                        present_key,
-                        present_value,
-                        output_emb,
-                        output_score,
-                        max_context_len,
-                        past_lens,
-                        subsequence_begins,
-                        block_indices,
-                        block_indices_begins,
-                        alibi_slopes,
-                        score_aggregation_window,
-                        sparse_attention_mask);
+        if (past_lens.m_dims[0] >= nthr || _workitems.get_reorder_max_batch_size() > 0) {
+            exec_loop_mixed(query,
+                            present_key,
+                            present_value,
+                            output_emb,
+                            output_score,
+                            max_context_len,
+                            past_lens,
+                            subsequence_begins,
+                            block_indices,
+                            block_indices_begins,
+                            alibi_slopes,
+                            score_aggregation_window,
+                            sparse_attention_mask);
+        } else {
+            _helper.exec_loop_bhl(query,
+                                  present_key,
+                                  present_value,
+                                  output_emb,
+                                  output_score,
+                                  max_context_len,
+                                  past_lens,
+                                  subsequence_begins,
+                                  block_indices,
+                                  block_indices_begins,
+                                  alibi_slopes,
+                                  score_aggregation_window,
+                                  sparse_attention_mask);
+        }
+        // exec_loop_mixed(query,
+        //                 present_key,
+        //                 present_value,
+        //                 output_emb,
+        //                 output_score,
+        //                 max_context_len,
+        //                 past_lens,
+        //                 subsequence_begins,
+        //                 block_indices,
+        //                 block_indices_begins,
+        //                 alibi_slopes,
+        //                 score_aggregation_window,
+        //                 sparse_attention_mask);
     }
 };
 
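The commit title also mentions broadcast_sparse_attention_mask supporting a different block_size. That helper is not part of this file, so the following is only a hedged sketch of one way a mask built at a coarser block granularity could be expanded to the kernel's _block_size by index scaling, assuming the mask block size is a multiple of the kernel block size; the function name and layouts are hypothetical.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Expand a block mask defined at mask_block_size to a finer kernel_block_size by mapping
// each fine-grained block back to the coarse block that covers the same token range.
// Assumes mask_block_size is a multiple of kernel_block_size.
std::vector<char> broadcast_block_mask(const std::vector<char>& coarse,
                                       size_t q_blocks_coarse, size_t k_blocks_coarse,
                                       size_t mask_block_size, size_t kernel_block_size,
                                       size_t q_blocks_fine, size_t k_blocks_fine) {
    const size_t ratio = mask_block_size / kernel_block_size;  // assumed to divide exactly
    std::vector<char> fine(q_blocks_fine * k_blocks_fine, 0);
    for (size_t qb = 0; qb < q_blocks_fine; ++qb) {
        const size_t q_src = std::min(qb / ratio, q_blocks_coarse - 1);
        for (size_t kb = 0; kb < k_blocks_fine; ++kb) {
            const size_t k_src = std::min(kb / ratio, k_blocks_coarse - 1);
            fine[qb * k_blocks_fine + kb] = coarse[q_src * k_blocks_coarse + k_src];
        }
    }
    return fine;
}

int main() {
    // 1 coarse q-block x 2 coarse k-blocks at block size 8, expanded to block size 4 (2 x 4 fine grid):
    // fine k-blocks 0..1 stay enabled (1), fine k-blocks 2..3 stay masked (0).
    auto fine = broadcast_block_mask({1, 0}, 1, 2, /*mask_block_size=*/8, /*kernel_block_size=*/4, 2, 4);
    return (fine[0] == 1 && fine[3] == 0) ? 0 : 1;
}
```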