Skip to content

Commit 7885657

Browse files
committed
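Optimize use_softmax_sparse_mask check: evaluate the vector-length alignment of the sparse mask block size once at setup instead of before every softmax call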
1 parent c0ab072 commit 7885657

File tree

1 file changed

+23
-35
lines changed

1 file changed

+23
-35
lines changed

src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp

Lines changed: 23 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,7 @@ struct MHAHelper {
586586
PlainTensor _block_rotation_coefficient_scratch;
587587
// Block size used when generating sparse_attention_mask (0 means unspecified/equal to _block_size)
588588
size_t _sparse_mask_block_size = 0;
589+
bool _use_softmax_sparse_mask = false;
589590

590591
MHAHelper() {
591592
_weight.resize<float>({size_t{1}, size_t{1}, size_t{1}, size_t{1}});
@@ -933,24 +934,12 @@ struct MHAHelper {
933934
}
934935

935936
// Handle sparse attention mask for sliding window
936-
if (!sparse_attention_mask.empty()) {
937-
// Check _sparse_mask_block_size is a multiple of vector length for correct sparse_mask indexing
938-
# if defined(HAVE_AVX512F)
939-
constexpr size_t vec_len = vec_len_f32_avx512;
940-
# elif defined(HAVE_AVX2)
941-
constexpr size_t vec_len = vec_len_f32_avx2;
942-
# elif defined(OPENVINO_ARCH_ARM64)
943-
constexpr size_t vec_len = vec_len_f32_neon;
944-
# else
945-
constexpr size_t vec_len = 1;
946-
# endif
947-
if (_sparse_mask_block_size % vec_len == 0) {
948-
// Get the original xattn_mask and calculate offset
949-
auto* original_mask = reinterpret_cast<uint8_t*>(
950-
sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk / sparse_scale));
951-
size_t mask_start_offset = start_idx / _sparse_mask_block_size;
952-
xattn_mask = original_mask + mask_start_offset;
953-
}
937+
if (!sparse_attention_mask.empty() && _use_softmax_sparse_mask) {
938+
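// Take the sparse-mask row for this head / query block and advance it to the mask block that contains start_idx.
// NOTE(review): the offset divides by _sparse_mask_block_size, which is documented as possibly 0 ("unspecified"); confirm it is non-zero whenever _use_softmax_sparse_mask is set.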
939+
auto* original_mask = reinterpret_cast<uint8_t*>(
940+
sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk / sparse_scale));
941+
size_t mask_start_offset = start_idx / _sparse_mask_block_size;
942+
xattn_mask = original_mask + mask_start_offset;
954943
}
955944

956945
attn_softmax_kernel<float>(score + start_idx,
@@ -977,23 +966,9 @@ struct MHAHelper {
977966
alibi_slope = alibi_slopes.ptr<float>()[h];
978967
alibi_lookup = _alibi_lookup.ptr<float>() + _alibi_lookup.m_dims[0] - ncausal;
979968
}
980-
if (!sparse_attention_mask.empty()) {
981-
// Check _sparse_mask_block_size is a multiple of vector length for correct sparse_mask indexing
982-
# if defined(HAVE_AVX512F)
983-
constexpr size_t vec_len = vec_len_f32_avx512;
984-
# elif defined(HAVE_AVX2)
985-
constexpr size_t vec_len = vec_len_f32_avx2;
986-
# elif defined(OPENVINO_ARCH_ARM64)
987-
constexpr size_t vec_len = vec_len_f32_neon;
988-
# else
989-
constexpr size_t vec_len = 1;
990-
# endif
991-
if (_sparse_mask_block_size % vec_len == 0) {
992-
xattn_mask = reinterpret_cast<uint8_t*>(
993-
sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk / sparse_scale));
994-
}
995-
} else {
996-
xattn_mask = nullptr;
969+
if (!sparse_attention_mask.empty() && _use_softmax_sparse_mask) {
970+
xattn_mask = reinterpret_cast<uint8_t*>(
971+
sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk / sparse_scale));
997972
}
998973
attn_softmax_kernel<float>(score,
999974
reinterpret_cast<DATA_TYPE*>(score),
@@ -2184,6 +2159,19 @@ struct AttentionExecutor : public PagedAttentionExecutor {
21842159

21852160
// keep original mask granularity; remember its block size for on-the-fly mapping
21862161
_helper._sparse_mask_block_size = xattention_block_size;
2162+
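// Check once, at setup, that sparse_mask_block_size is a multiple of the vector length; sparse_mask indexing in the softmax path is only correct in that case.
// NOTE(review): the flag is only ever set to true below. Unless it is reset elsewhere, a later call with a non-multiple block size would keep a stale true value. Consider assigning the result directly: _helper._use_softmax_sparse_mask = (xattention_block_size % vec_len == 0); TODO confirm.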
2163+
# if defined(HAVE_AVX512F)
2164+
constexpr size_t vec_len = vec_len_f32_avx512;
2165+
# elif defined(HAVE_AVX2)
2166+
constexpr size_t vec_len = vec_len_f32_avx2;
2167+
# elif defined(OPENVINO_ARCH_ARM64)
2168+
constexpr size_t vec_len = vec_len_f32_neon;
2169+
# else
2170+
constexpr size_t vec_len = 1;
2171+
# endif
2172+
if (xattention_block_size % vec_len == 0) {
2173+
_helper._use_softmax_sparse_mask = true;
2174+
}
21872175
}
21882176

21892177
_helper.init(H,

0 commit comments

Comments
 (0)