Skip to content

Commit eb9ea32

Browse files
committed
Apply review comments
1 parent ec830a2 commit eb9ea32

File tree

1 file changed

+5
-12
lines changed

1 file changed

+5
-12
lines changed

src/plugins/intel_cpu/src/nodes/scaled_attn.cpp

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -415,10 +415,8 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
415415
auto m_blocks = (q_len + m_block_size - 1) / m_block_size;
416416
bool is_xf16 = any_of(precision_of<T>::value, ov::element::bf16, ov::element::f16);
417417
// packed k, v
418-
ov::element::Type attn_mask_precision = ov::element::Type(precision_of<T>::value);
419-
if (attention_mask) {
420-
attn_mask_precision = attention_mask.get_precision();
421-
}
418+
auto attn_mask_precision =
419+
attention_mask ? attention_mask.get_precision() : ov::element::Type(precision_of<T>::value);
422420

423421
parallel_for2d(B, Hk, [&](size_t b, size_t h) {
424422
T* k_ptr = &present_key.at<T>({b, h, 0, 0});
@@ -480,8 +478,7 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
480478
if (sink_input) {
481479
sink = &sink_input.at<float>({b, h, m, 0}, true);
482480
}
483-
uint8_t* attn_mask_row =
484-
attn_mask_ptr && attn_mask_stride ? attn_mask_ptr + m * attn_mask_stride : attn_mask_ptr;
481+
uint8_t* attn_mask_row = attn_mask_ptr ? attn_mask_ptr + m * attn_mask_stride : nullptr;
485482

486483
attn_softmax(reinterpret_cast<void*>(score),
487484
reinterpret_cast<T*>(score),
@@ -646,10 +643,7 @@ struct MHAKernel<ScaledDotProductAttention::KT_ACL, T> {
646643
auto k_stride_s = present_key.stride(3);
647644

648645
auto m_blocks = (q_len + m_block_size - 1) / m_block_size;
649-
ov::element::Type attn_mask_precision = precision;
650-
if (attention_mask) {
651-
attn_mask_precision = attention_mask.get_precision();
652-
}
646+
auto attn_mask_precision = attention_mask ? attention_mask.get_precision() : precision;
653647

654648
parallel_for3d(B, H, m_blocks, [&](size_t b, size_t h, size_t m_blk) {
655649
auto m_start = m_blk * m_block_size;
@@ -709,8 +703,7 @@ struct MHAKernel<ScaledDotProductAttention::KT_ACL, T> {
709703
for (size_t m = m_start; m < m_end; m++) {
710704
// apply attention mask & sofmax
711705
auto ncausal = auto_causal ? (kv_len - q_len + m + 1) : kv_len;
712-
uint8_t* attn_mask_row =
713-
attn_mask_ptr && attn_mask_stride ? attn_mask_ptr + m * attn_mask_stride : attn_mask_ptr;
706+
uint8_t* attn_mask_row = attn_mask_ptr ? attn_mask_ptr + m * attn_mask_stride : nullptr;
714707
attn_softmax(reinterpret_cast<void*>(qk + (m - m_start) * kv_len),
715708
qk + (m - m_start) * kv_len,
716709
d_scale,

0 commit comments

Comments
 (0)