@@ -586,6 +586,7 @@ struct MHAHelper {
     PlainTensor _block_rotation_coefficient_scratch;
     // Block size used when generating sparse_attention_mask (0 means unspecified/equal to _block_size)
     size_t _sparse_mask_block_size = 0;
+    bool _use_softmax_sparse_mask = false;
 
     MHAHelper() {
         _weight.resize<float>({size_t{1}, size_t{1}, size_t{1}, size_t{1}});
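
The new flag caches, at setup time, whether the sparse mask layout is compatible with the vectorized softmax (the check itself moves to the last hunk of this diff). The comment on `_sparse_mask_block_size` also fixes the convention that zero means "same as the kernel's `_block_size`". A minimal sketch of that zero-fallback convention, with a hypothetical helper name that is not part of the patch:

```cpp
#include <cstddef>

// Sketch of the "0 means unspecified/equal to _block_size" convention noted on
// _sparse_mask_block_size. The helper name is hypothetical, not in the patch.
size_t effective_mask_block_size(size_t sparse_mask_block_size, size_t kernel_block_size) {
    return sparse_mask_block_size == 0 ? kernel_block_size : sparse_mask_block_size;
}

int main() {
    return effective_mask_block_size(0, 32) == 32 ? 0 : 1;  // 0 falls back to _block_size
}
```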
@@ -933,24 +934,12 @@ struct MHAHelper {
         }
 
         // Handle sparse attention mask for sliding window
-        if (!sparse_attention_mask.empty()) {
-            // Check _sparse_mask_block_size is a multiple of vector length for correct sparse_mask indexing
-#    if defined(HAVE_AVX512F)
-            constexpr size_t vec_len = vec_len_f32_avx512;
-#    elif defined(HAVE_AVX2)
-            constexpr size_t vec_len = vec_len_f32_avx2;
-#    elif defined(OPENVINO_ARCH_ARM64)
-            constexpr size_t vec_len = vec_len_f32_neon;
-#    else
-            constexpr size_t vec_len = 1;
-#    endif
-            if (_sparse_mask_block_size % vec_len == 0) {
-                // Get the original xattn_mask and calculate offset
-                auto* original_mask = reinterpret_cast<uint8_t*>(
-                    sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk / sparse_scale));
-                size_t mask_start_offset = start_idx / _sparse_mask_block_size;
-                xattn_mask = original_mask + mask_start_offset;
-            }
+        if (!sparse_attention_mask.empty() && _use_softmax_sparse_mask) {
+            // Get the original xattn_mask and calculate offset
+            auto* original_mask = reinterpret_cast<uint8_t*>(
+                sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk / sparse_scale));
+            size_t mask_start_offset = start_idx / _sparse_mask_block_size;
+            xattn_mask = original_mask + mask_start_offset;
         }
 
         attn_softmax_kernel<float>(score + start_idx,
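
The simplified branch keeps only the pointer arithmetic: one mask byte covers one block of `_sparse_mask_block_size` key positions, so a softmax row that starts at key index `start_idx` reads the mask from element `start_idx / _sparse_mask_block_size` onward. A standalone sketch of that mapping (function name hypothetical):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Sketch of the offset computation above: one mask byte covers one block of
// mask_block_size key positions, so advancing the mask pointer by
// start_idx / mask_block_size keeps it aligned with score + start_idx.
const uint8_t* mask_for_row_start(const uint8_t* row_mask, size_t start_idx, size_t mask_block_size) {
    return row_mask + start_idx / mask_block_size;
}

int main() {
    uint8_t row_mask[4] = {1, 0, 1, 1};  // 4 mask blocks of, say, 128 keys each
    // A row starting at key index 256 lands in the third mask block (index 2).
    assert(mask_for_row_start(row_mask, 256, 128) == row_mask + 2);
    return 0;
}
```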
@@ -977,23 +966,9 @@ struct MHAHelper {
             alibi_slope = alibi_slopes.ptr<float>()[h];
             alibi_lookup = _alibi_lookup.ptr<float>() + _alibi_lookup.m_dims[0] - ncausal;
         }
-        if (!sparse_attention_mask.empty()) {
-            // Check _sparse_mask_block_size is a multiple of vector length for correct sparse_mask indexing
-#    if defined(HAVE_AVX512F)
-            constexpr size_t vec_len = vec_len_f32_avx512;
-#    elif defined(HAVE_AVX2)
-            constexpr size_t vec_len = vec_len_f32_avx2;
-#    elif defined(OPENVINO_ARCH_ARM64)
-            constexpr size_t vec_len = vec_len_f32_neon;
-#    else
-            constexpr size_t vec_len = 1;
-#    endif
-            if (_sparse_mask_block_size % vec_len == 0) {
-                xattn_mask = reinterpret_cast<uint8_t*>(
-                    sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk / sparse_scale));
-            }
-        } else {
-            xattn_mask = nullptr;
+        if (!sparse_attention_mask.empty() && _use_softmax_sparse_mask) {
+            xattn_mask = reinterpret_cast<uint8_t*>(
+                sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk / sparse_scale));
         }
         attn_softmax_kernel<float>(score,
                                    reinterpret_cast<DATA_TYPE*>(score),
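
Both call sites select the mask row with `q_blk / sparse_scale`. Assuming `sparse_scale` (defined outside this diff) is the number of kernel query blocks covered by one row of the XAttention mask, consecutive query blocks share a row, as in this sketch:

```cpp
#include <cassert>
#include <cstddef>

// Sketch of the q_blk / sparse_scale row lookup used at both call sites.
// Assumption (not visible in this diff): sparse_scale is the number of kernel
// query blocks covered by one row of the XAttention mask.
size_t mask_row_for_query_block(size_t q_blk, size_t sparse_scale) {
    return q_blk / sparse_scale;
}

int main() {
    // With a 128-wide mask block and a 32-wide kernel block, sparse_scale is 4,
    // so kernel query blocks 0..3 read mask row 0 and blocks 4..7 read row 1.
    assert(mask_row_for_query_block(3, 4) == 0);
    assert(mask_row_for_query_block(4, 4) == 1);
    return 0;
}
```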
@@ -2184,6 +2159,19 @@ struct AttentionExecutor : public PagedAttentionExecutor {
 
             // keep original mask granularity; remember its block size for on-the-fly mapping
             _helper._sparse_mask_block_size = xattention_block_size;
+            // Check sparse_mask_block_size is a multiple of vector length for correct sparse_mask indexing
+#    if defined(HAVE_AVX512F)
+            constexpr size_t vec_len = vec_len_f32_avx512;
+#    elif defined(HAVE_AVX2)
+            constexpr size_t vec_len = vec_len_f32_avx2;
+#    elif defined(OPENVINO_ARCH_ARM64)
+            constexpr size_t vec_len = vec_len_f32_neon;
+#    else
+            constexpr size_t vec_len = 1;
+#    endif
+            if (xattention_block_size % vec_len == 0) {
+                _helper._use_softmax_sparse_mask = true;
+            }
         }
 
         _helper.init(H,
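
Moving the divisibility test here turns a per-call check inside the softmax loop into a one-time decision: the `#if` chain picks the f32 SIMD lane count for the build target, and the cached flag then gates both softmax call sites above. A self-contained sketch, using the usual lane widths (16 for AVX-512, 8 for AVX2, 4 for NEON) as stand-ins for the real `vec_len_f32_*` constants:

```cpp
#include <cstddef>
#include <cstdio>

// Stand-ins for the ISA-dependent lane counts chosen by the #if chain above;
// the values are the standard 32-bit-float lane widths of each ISA.
#if defined(HAVE_AVX512F)
constexpr size_t vec_len = 16;
#elif defined(HAVE_AVX2)
constexpr size_t vec_len = 8;
#elif defined(OPENVINO_ARCH_ARM64)
constexpr size_t vec_len = 4;
#else
constexpr size_t vec_len = 1;
#endif

int main() {
    // The mask block size must be a multiple of vec_len for correct
    // sparse_mask indexing; e.g. with vec_len == 16, 128 passes and 24 fails.
    const size_t block_sizes[] = {128, 24};
    for (size_t block_size : block_sizes) {
        bool use_softmax_sparse_mask = (block_size % vec_len == 0);
        printf("block_size=%zu -> %s\n", block_size, use_softmax_sparse_mask ? "vectorized mask path" : "fallback");
    }
    return 0;
}
```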