@@ -415,10 +415,8 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
     auto m_blocks = (q_len + m_block_size - 1) / m_block_size;
     bool is_xf16 = any_of(precision_of<T>::value, ov::element::bf16, ov::element::f16);
     // packed k, v
-    ov::element::Type attn_mask_precision = ov::element::Type(precision_of<T>::value);
-    if (attention_mask) {
-        attn_mask_precision = attention_mask.get_precision();
-    }
+    auto attn_mask_precision =
+        attention_mask ? attention_mask.get_precision() : ov::element::Type(precision_of<T>::value);
 
     parallel_for2d(B, Hk, [&](size_t b, size_t h) {
         T* k_ptr = &present_key.at<T>({b, h, 0, 0});
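
Both kernels apply the same refactor here: a default-initialize-then-maybe-overwrite pattern is collapsed into one conditional expression, so the result can be bound once (and could even be `const`). The standalone sketch below reproduces the shape of that change with hypothetical stand-ins (`MaskTensor`, `Precision`), not the real `ov::element::Type` or tensor API:

```cpp
#include <iostream>
#include <optional>

// Hypothetical stand-ins for the mask tensor and ov::element::Type.
enum class Precision { f32, f16, bf16, u8 };

struct MaskTensor {
    std::optional<Precision> precision;  // empty => no mask supplied
    explicit operator bool() const { return precision.has_value(); }
    Precision get_precision() const { return *precision; }
};

int main() {
    MaskTensor attention_mask{Precision::u8};
    Precision default_precision = Precision::f32;  // stands in for precision_of<T>::value

    // Same shape as the refactor: one conditional expression instead of a
    // mutable variable plus an if-statement.
    const auto attn_mask_precision =
        attention_mask ? attention_mask.get_precision() : default_precision;

    std::cout << (attn_mask_precision == Precision::u8 ? "mask precision\n"
                                                       : "default precision\n");
}
```
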
@@ -480,8 +478,7 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
         if (sink_input) {
             sink = &sink_input.at<float>({b, h, m, 0}, true);
         }
-        uint8_t* attn_mask_row =
-            attn_mask_ptr && attn_mask_stride ? attn_mask_ptr + m * attn_mask_stride : attn_mask_ptr;
+        uint8_t* attn_mask_row = attn_mask_ptr ? attn_mask_ptr + m * attn_mask_stride : nullptr;
 
         attn_softmax(reinterpret_cast<void*>(score),
                      reinterpret_cast<T*>(score),
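
The row-pointer change looks behavioral but is an equivalent simplification: when `attn_mask_stride` is 0 (a broadcast single-row mask), `attn_mask_ptr + m * 0` is just `attn_mask_ptr`, and when the pointer is null both forms yield null; the new form only states the no-mask case explicitly. A toy check of that equivalence, with `mask`/`stride` as illustrative parameters rather than the kernel's real state:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

static uint8_t* row_old(uint8_t* mask, size_t stride, size_t m) {
    return mask && stride ? mask + m * stride : mask;
}

static uint8_t* row_new(uint8_t* mask, size_t stride, size_t m) {
    return mask ? mask + m * stride : nullptr;
}

int main() {
    uint8_t buf[4][8] = {};
    // Per-row mask: both forms step by the stride.
    assert(row_old(&buf[0][0], 8, 3) == row_new(&buf[0][0], 8, 3));
    // Broadcast mask (stride == 0): mask + m * 0 == mask, so dropping the
    // stride check does not change the result.
    assert(row_old(&buf[0][0], 0, 3) == row_new(&buf[0][0], 0, 3));
    // No mask at all: both yield nullptr; the new form says so directly.
    assert(row_old(nullptr, 8, 3) == row_new(nullptr, 8, 3));
    return 0;
}
```
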
@@ -646,10 +643,7 @@ struct MHAKernel<ScaledDotProductAttention::KT_ACL, T> {
     auto k_stride_s = present_key.stride(3);
 
     auto m_blocks = (q_len + m_block_size - 1) / m_block_size;
-    ov::element::Type attn_mask_precision = precision;
-    if (attention_mask) {
-        attn_mask_precision = attention_mask.get_precision();
-    }
+    auto attn_mask_precision = attention_mask ? attention_mask.get_precision() : precision;
 
     parallel_for3d(B, H, m_blocks, [&](size_t b, size_t h, size_t m_blk) {
         auto m_start = m_blk * m_block_size;
@@ -709,8 +703,7 @@ struct MHAKernel<ScaledDotProductAttention::KT_ACL, T> {
         for (size_t m = m_start; m < m_end; m++) {
             // apply attention mask & softmax
             auto ncausal = auto_causal ? (kv_len - q_len + m + 1) : kv_len;
-            uint8_t* attn_mask_row =
-                attn_mask_ptr && attn_mask_stride ? attn_mask_ptr + m * attn_mask_stride : attn_mask_ptr;
+            uint8_t* attn_mask_row = attn_mask_ptr ? attn_mask_ptr + m * attn_mask_stride : nullptr;
             attn_softmax(reinterpret_cast<void*>(qk + (m - m_start) * kv_len),
                          qk + (m - m_start) * kv_len,
                          d_scale,