Commit ec830a2

[CPU] Fix attention mask precision handling in ScaledDotProductAttention
Use actual attention mask input precision instead of compute precision (bf16/f16) to fix LFM2-350M output corruption when running with low precision on Xeon platforms.
1 parent f8bd7bd commit ec830a2
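
In short, both MHA kernels below stop describing the attention mask to attn_softmax with the kernel compute precision and instead forward the element type reported by the mask input itself (for example boolean, as exercised by the new test). A condensed sketch of the idea, using identifiers from the diff below rather than a complete, compilable excerpt:

ov::element::Type attn_mask_precision = ov::element::Type(precision_of<T>::value);  // default: compute precision (bf16/f16)
if (attention_mask) {
    attn_mask_precision = attention_mask.get_precision();  // actual mask input precision, e.g. ov::element::boolean
}
// ... per query row m, the mask precision is now passed through explicitly:
attn_softmax(/* scores, scale, alibi row, mask row, causal mask, ... */
             precision_of<T>::value,  // score/input precision
             attn_mask_precision,     // was hard-coded to precision_of<T>::value before this commit
             precision_of<T>::value,  // output precision
             sink);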

File tree

2 files changed (+177 -10 lines)

src/plugins/intel_cpu/src/nodes/scaled_attn.cpp

Lines changed: 25 additions & 10 deletions
@@ -415,6 +415,11 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
         auto m_blocks = (q_len + m_block_size - 1) / m_block_size;
         bool is_xf16 = any_of(precision_of<T>::value, ov::element::bf16, ov::element::f16);
         // packed k, v
+        ov::element::Type attn_mask_precision = ov::element::Type(precision_of<T>::value);
+        if (attention_mask) {
+            attn_mask_precision = attention_mask.get_precision();
+        }
+
         parallel_for2d(B, Hk, [&](size_t b, size_t h) {
             T* k_ptr = &present_key.at<T>({b, h, 0, 0});
             T* v_ptr = &present_value.at<T>({b, h, 0, 0});
@@ -451,11 +456,12 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
             }
 
             uint8_t* attn_mask_ptr = nullptr;
-            auto attn_mask_stride = 0;
+            size_t attn_mask_stride = 0;
             if (attention_mask) {
-                attn_mask_ptr = reinterpret_cast<uint8_t*>(&attention_mask.at<T>({b, h, 0, 0}, true));
+                const size_t mask_head = attention_mask.size(1) > 1 ? h : 0;
+                attn_mask_ptr = static_cast<uint8_t*>(attention_mask.ptr_v(b, mask_head, 0, 0));
                 if (attention_mask.size(2) > 1) {
-                    attn_mask_stride = attention_mask.stride(2) * sizeof(T);
+                    attn_mask_stride = attention_mask.stride_bytes(2);
                 }
             }
             uint8_t* cmask_ptr = nullptr;
@@ -474,18 +480,20 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
                 if (sink_input) {
                     sink = &sink_input.at<float>({b, h, m, 0}, true);
                 }
+                uint8_t* attn_mask_row =
+                    attn_mask_ptr && attn_mask_stride ? attn_mask_ptr + m * attn_mask_stride : attn_mask_ptr;
 
                 attn_softmax(reinterpret_cast<void*>(score),
                              reinterpret_cast<T*>(score),
                              d_scale,
                              reinterpret_cast<void*>(alibi_ptr + m * alibi_stride),
-                             attn_mask_ptr + m * attn_mask_stride,
+                             attn_mask_row,
                              cmask_ptr + m * cmask_stride,
                              select_nfltmax_at_0,
                              ncausal,
                              kv_len,
                              precision_of<T>::value,
-                             precision_of<T>::value,
+                             attn_mask_precision,
                              precision_of<T>::value,
                              sink);
             }
@@ -638,6 +646,10 @@ struct MHAKernel<ScaledDotProductAttention::KT_ACL, T> {
         auto k_stride_s = present_key.stride(3);
 
         auto m_blocks = (q_len + m_block_size - 1) / m_block_size;
+        ov::element::Type attn_mask_precision = precision;
+        if (attention_mask) {
+            attn_mask_precision = attention_mask.get_precision();
+        }
 
         parallel_for3d(B, H, m_blocks, [&](size_t b, size_t h, size_t m_blk) {
             auto m_start = m_blk * m_block_size;
@@ -657,11 +669,12 @@ struct MHAKernel<ScaledDotProductAttention::KT_ACL, T> {
                 }
             }
             uint8_t* attn_mask_ptr = nullptr;
-            auto attn_mask_stride = 0;
+            size_t attn_mask_stride = 0;
             if (attention_mask) {
-                attn_mask_ptr = reinterpret_cast<uint8_t*>(&attention_mask.at<T>({b, h, 0, 0}, true));
+                const size_t mask_head = attention_mask.size(1) > 1 ? h : 0;
+                attn_mask_ptr = static_cast<uint8_t*>(attention_mask.ptr_v(b, mask_head, 0, 0));
                 if (attention_mask.size(2) > 1) {
-                    attn_mask_stride = attention_mask.stride(2) * sizeof(T);
+                    attn_mask_stride = attention_mask.stride_bytes(2);
                 }
             }
             uint8_t* cmask_ptr = nullptr;
@@ -696,17 +709,19 @@ struct MHAKernel<ScaledDotProductAttention::KT_ACL, T> {
             for (size_t m = m_start; m < m_end; m++) {
                 // apply attention mask & sofmax
                 auto ncausal = auto_causal ? (kv_len - q_len + m + 1) : kv_len;
+                uint8_t* attn_mask_row =
+                    attn_mask_ptr && attn_mask_stride ? attn_mask_ptr + m * attn_mask_stride : attn_mask_ptr;
                 attn_softmax(reinterpret_cast<void*>(qk + (m - m_start) * kv_len),
                              qk + (m - m_start) * kv_len,
                              d_scale,
                              reinterpret_cast<void*>(alibi_ptr + m * alibi_stride),
-                             attn_mask_ptr + m * attn_mask_stride,
+                             attn_mask_row,
                              cmask_ptr + m * cmask_stride,
                              select_nfltmax_at_0,
                              ncausal,
                              kv_len,
                              precision,
-                             precision,
+                             attn_mask_precision,
                              precision,
                              nullptr);
             }
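
Two details of the mask addressing in the hunks above, restated as a simplified sketch (same identifiers as the diff, stripped of the surrounding parallel loops): the per-row stride is now taken in bytes from the mask tensor itself instead of being computed as stride(2) * sizeof(T), and a mask with a single head dimension is reused for every head.

// Head broadcast: a mask with size(1) == 1 (e.g. [B, 1, L, L]) is shared across all heads.
const size_t mask_head = attention_mask.size(1) > 1 ? h : 0;
attn_mask_ptr = static_cast<uint8_t*>(attention_mask.ptr_v(b, mask_head, 0, 0));

// Byte stride: stride_bytes(2) respects the mask's own element size (1 byte for boolean),
// whereas the old stride(2) * sizeof(T) assumed mask elements were as wide as T (2 bytes for bf16/f16).
size_t attn_mask_stride = 0;
if (attention_mask.size(2) > 1) {
    attn_mask_stride = attention_mask.stride_bytes(2);
}

// With a zero stride (single mask row), every query row m reuses the same row.
uint8_t* attn_mask_row =
    attn_mask_ptr && attn_mask_stride ? attn_mask_ptr + m * attn_mask_stride : attn_mask_ptr;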
Second file (new test): Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
+#include <gtest/gtest.h>
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "common_test_utils/include/common_test_utils/data_utils.hpp"
+#include "common_test_utils/include/common_test_utils/ov_tensor_utils.hpp"
+#include "internal_properties.hpp"
+#include "openvino/core/dimension.hpp"
+#include "openvino/core/model.hpp"
+#include "openvino/core/partial_shape.hpp"
+#include "openvino/core/type/element_type.hpp"
+#include "openvino/op/assign.hpp"
+#include "openvino/op/concat.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/gather.hpp"
+#include "openvino/op/parameter.hpp"
+#include "openvino/op/read_value.hpp"
+#include "openvino/op/scaled_dot_product_attention.hpp"
+#include "openvino/op/util/variable.hpp"
+#include "openvino/opsets/opset13.hpp"
+#include "openvino/pass/manager.hpp"
+#include "shared_test_classes/base/ov_subgraph.hpp"
+#include "utils/cpu_test_utils.hpp"
+
+using namespace ov::test;
+using namespace CPUTestUtils;
+
+namespace ov {
+namespace test {
+
+namespace {
+
+class StatefulSdpaBoolMaskTest : public ov::test::SubgraphBaseTest, public CPUTestsBase {
+protected:
+    void SetUp() override {
+        targetDevice = ov::test::utils::DEVICE_CPU;
+        configuration[ov::hint::inference_precision.name()] = ov::element::bf16;
+        configuration[ov::hint::kv_cache_precision.name()] = ov::element::bf16;
+        rel_threshold = 0.02f;
+        abs_threshold = 0.02f;
+        selectedType = makeSelectedTypeStr(getPrimitiveType(), ov::element::bf16);
+
+        const InputShape q_shape{{-1, 8, -1, 64}, {{1, 8, 10, 64}}};
+        const InputShape k_shape{{-1, 8, -1, 64}, {{1, 8, 10, 64}}};
+        const InputShape v_shape{{-1, 8, -1, 64}, {{1, 8, 10, 64}}};
+        const InputShape mask_shape{{1, 1, -1, -1}, {{1, 1, 10, 10}}};
+        const InputShape past_shape{{-1, 8, -1, 64}, {{1, 8, 0, 64}}};
+        const InputShape beam_shape{{-1}, {{1}}};
+
+        init_input_shapes({q_shape, k_shape, v_shape, mask_shape, past_shape, beam_shape});
+
+        auto q = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, inputDynamicShapes[0]);
+        auto k = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, inputDynamicShapes[1]);
+        auto v = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, inputDynamicShapes[2]);
+        auto mask = std::make_shared<ov::op::v0::Parameter>(ov::element::boolean, inputDynamicShapes[3]);
+        auto past_init = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, inputDynamicShapes[4]);
+        auto beam_idx = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, inputDynamicShapes[5]);
+
+        q->set_friendly_name("q");
+        k->set_friendly_name("k");
+        v->set_friendly_name("v");
+        mask->set_friendly_name("attention_mask");
+        past_init->set_friendly_name("past_init");
+        beam_idx->set_friendly_name("beam_idx");
+
+        auto variable_k = std::make_shared<ov::op::util::Variable>(
+            ov::op::util::VariableInfo{inputDynamicShapes[4], ov::element::bf16, "pastk"});
+        auto variable_v = std::make_shared<ov::op::util::Variable>(
+            ov::op::util::VariableInfo{inputDynamicShapes[4], ov::element::bf16, "pastv"});
+
+        auto past_k = std::make_shared<ov::op::v6::ReadValue>(past_init, variable_k);
+        auto past_v = std::make_shared<ov::op::v6::ReadValue>(past_init, variable_v);
+        past_k->set_friendly_name("pastk_read");
+        past_v->set_friendly_name("pastv_read");
+
+        auto axis = ov::op::v0::Constant::create(ov::element::i32, {1}, {0});
+        auto gather_k = std::make_shared<ov::op::v8::Gather>(past_k, beam_idx, axis);
+        auto gather_v = std::make_shared<ov::op::v8::Gather>(past_v, beam_idx, axis);
+        gather_k->set_batch_dims(0);
+        gather_v->set_batch_dims(0);
+
+        auto concat_k = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{gather_k, k}, 2);
+        auto concat_v = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{gather_v, v}, 2);
+
+        auto sdpa = std::make_shared<ov::opset13::ScaledDotProductAttention>(q, concat_k, concat_v, mask, false);
+        sdpa->set_friendly_name("stateful_sdpa");
+
+        auto assign_k = std::make_shared<ov::op::v6::Assign>(concat_k, variable_k);
+        auto assign_v = std::make_shared<ov::op::v6::Assign>(concat_v, variable_v);
+        assign_k->set_friendly_name("pastk_write");
+        assign_v->set_friendly_name("pastv_write");
+
+        ov::ResultVector results{std::make_shared<ov::op::v0::Result>(sdpa)};
+        ov::SinkVector sinks{assign_k, assign_v};
+        function = std::make_shared<ov::Model>(results,
+                                               sinks,
+                                               ov::ParameterVector{q, k, v, mask, past_init, beam_idx},
+                                               "StatefulSdpaBoolMask");
+    }
+
+    void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
+        inputs.clear();
+
+        const auto& parameters = function->get_parameters();
+        for (size_t idx = 0; idx < parameters.size(); ++idx) {
+            const auto& param = parameters[idx];
+            const auto& shape = targetInputStaticShapes[idx];
+            if (param->get_element_type() == ov::element::bf16) {
+                ov::Tensor tensor{ov::element::bf16, shape};
+                utils::fill_data_random(static_cast<ov::bfloat16*>(tensor.data()), tensor.get_size(), 2, -1, 10);
+                inputs.insert({param, tensor});
+            } else if (param->get_element_type() == ov::element::boolean) {
+                ov::Tensor tensor{ov::element::boolean, shape};
+                auto* data = tensor.data<bool>();
+                for (size_t i = 0; i < tensor.get_size(); ++i) {
+                    data[i] = (i % 3) != 0;
+                }
+                inputs.insert({param, tensor});
+            } else if (param->get_element_type() == ov::element::i32) {
+                ov::Tensor tensor{ov::element::i32, shape};
+                auto* data = tensor.data<int32_t>();
+                int32_t denom = 1;
+                if (!shape.empty() && shape[0] != 0) {
+                    denom = static_cast<int32_t>(shape[0]);
+                }
+                for (size_t i = 0; i < tensor.get_size(); ++i) {
+                    data[i] = static_cast<int32_t>(i % denom);
+                }
+                inputs.insert({param, tensor});
+            } else {
+                FAIL() << "Unexpected parameter precision " << param->get_element_type();
+            }
+        }
+    }
+};
+
+TEST_F(StatefulSdpaBoolMaskTest, CompareWithRefs) {
+    SKIP_IF_CURRENT_TEST_IS_DISABLED();
+    if (!ov::with_cpu_x86_bfloat16()) {
+        GTEST_SKIP();
+    }
+    run();
+    CheckPluginRelatedResults(compiledModel, "ScaledAttn");
+}
+
+} // namespace
+
+} // namespace test
+} // namespace ov
