Skip to content

Commit e98c971

Browse files
committed
Refactored the StatefulSdpaBoolMask test into a shared class that builds an FP32 graph, forces runtime inference precision, and instantiates coverage for x64 (bf16/f16) and ARM (f16)
1 parent 4ffb9e7 commit e98c971

File tree

4 files changed

+223
-140
lines changed

4 files changed

+223
-140
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Copyright (C) 2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "custom/subgraph_tests/src/classes/stateful_sdpa_bool_mask.hpp"
6+
7+
#include <gtest/gtest.h>
8+
9+
namespace ov {
10+
namespace test {
11+
12+
namespace {
13+
14+
INSTANTIATE_TEST_SUITE_P(smoke_ARM_StatefulSdpaBoolMask,
15+
StatefulSdpaBoolMaskTest,
16+
::testing::Values(ov::element::f16),
17+
StatefulSdpaBoolMaskTest::getTestCaseName);
18+
19+
} // namespace
20+
21+
} // namespace test
22+
} // namespace ov
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
// Copyright (C) 2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "custom/subgraph_tests/src/classes/stateful_sdpa_bool_mask.hpp"
6+
7+
#include <gtest/gtest.h>
8+
9+
#include <cstddef>
10+
#include <cstdint>
11+
#include <memory>
12+
#include <sstream>
13+
#include <vector>
14+
15+
#include "common_test_utils/include/common_test_utils/data_utils.hpp"
16+
#include "common_test_utils/include/common_test_utils/ov_tensor_utils.hpp"
17+
#include "internal_properties.hpp"
18+
#include "openvino/core/model.hpp"
19+
#include "openvino/op/assign.hpp"
20+
#include "openvino/op/concat.hpp"
21+
#include "openvino/op/constant.hpp"
22+
#include "openvino/op/gather.hpp"
23+
#include "openvino/op/parameter.hpp"
24+
#include "openvino/op/read_value.hpp"
25+
#include "openvino/op/scaled_dot_product_attention.hpp"
26+
#include "openvino/op/util/variable.hpp"
27+
#include "openvino/opsets/opset13.hpp"
28+
#include "openvino/runtime/system_conf.hpp"
29+
30+
using namespace ov::test;
31+
using namespace CPUTestUtils;
32+
33+
namespace ov {
34+
namespace test {
35+
36+
/// Produces the per-parameter test name suffix: the literal "InferenceType="
/// followed by the type name reported by the parameterized element type.
std::string StatefulSdpaBoolMaskTest::getTestCaseName(const testing::TestParamInfo<StatefulSdpaBoolMaskParams>& obj) {
    std::ostringstream caseName;
    caseName << "InferenceType=";
    caseName << obj.param.get_type_name();
    return caseName.str();
}
41+
42+
void StatefulSdpaBoolMaskTest::SetUp() {
43+
targetDevice = ov::test::utils::DEVICE_CPU;
44+
45+
const auto inferencePrecision = GetParam();
46+
const auto graphPrecision = ov::element::f32;
47+
48+
configuration[ov::hint::inference_precision.name()] = inferencePrecision;
49+
configuration[ov::hint::kv_cache_precision.name()] = ov::element::f32;
50+
rel_threshold = 0.02f;
51+
abs_threshold = 0.02f;
52+
selectedType = makeSelectedTypeStr(getPrimitiveType(), inferencePrecision);
53+
54+
const InputShape q_shape{{-1, 8, -1, 64}, {{1, 8, 10, 64}}};
55+
const InputShape k_shape{{-1, 8, -1, 64}, {{1, 8, 10, 64}}};
56+
const InputShape v_shape{{-1, 8, -1, 64}, {{1, 8, 10, 64}}};
57+
const InputShape mask_shape{{1, 1, -1, -1}, {{1, 1, 10, 10}}};
58+
const InputShape past_shape{{-1, 8, -1, 64}, {{1, 8, 0, 64}}};
59+
const InputShape beam_shape{{-1}, {{1}}};
60+
61+
init_input_shapes({q_shape, k_shape, v_shape, mask_shape, past_shape, beam_shape});
62+
63+
auto q = std::make_shared<ov::op::v0::Parameter>(graphPrecision, inputDynamicShapes[0]);
64+
auto k = std::make_shared<ov::op::v0::Parameter>(graphPrecision, inputDynamicShapes[1]);
65+
auto v = std::make_shared<ov::op::v0::Parameter>(graphPrecision, inputDynamicShapes[2]);
66+
auto mask = std::make_shared<ov::op::v0::Parameter>(ov::element::boolean, inputDynamicShapes[3]);
67+
auto past_init = std::make_shared<ov::op::v0::Parameter>(graphPrecision, inputDynamicShapes[4]);
68+
auto beam_idx = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, inputDynamicShapes[5]);
69+
70+
q->set_friendly_name("q");
71+
k->set_friendly_name("k");
72+
v->set_friendly_name("v");
73+
mask->set_friendly_name("attention_mask");
74+
past_init->set_friendly_name("past_init");
75+
beam_idx->set_friendly_name("beam_idx");
76+
77+
auto variable_k = std::make_shared<ov::op::util::Variable>(
78+
ov::op::util::VariableInfo{inputDynamicShapes[4], graphPrecision, "pastk"});
79+
auto variable_v = std::make_shared<ov::op::util::Variable>(
80+
ov::op::util::VariableInfo{inputDynamicShapes[4], graphPrecision, "pastv"});
81+
82+
auto past_k = std::make_shared<ov::op::v6::ReadValue>(past_init, variable_k);
83+
auto past_v = std::make_shared<ov::op::v6::ReadValue>(past_init, variable_v);
84+
past_k->set_friendly_name("pastk_read");
85+
past_v->set_friendly_name("pastv_read");
86+
87+
auto axis = ov::op::v0::Constant::create(ov::element::i32, {1}, {0});
88+
auto gather_k = std::make_shared<ov::op::v8::Gather>(past_k, beam_idx, axis);
89+
auto gather_v = std::make_shared<ov::op::v8::Gather>(past_v, beam_idx, axis);
90+
gather_k->set_batch_dims(0);
91+
gather_v->set_batch_dims(0);
92+
93+
auto concat_k = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{gather_k, k}, 2);
94+
auto concat_v = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{gather_v, v}, 2);
95+
96+
auto sdpa = std::make_shared<ov::opset13::ScaledDotProductAttention>(q, concat_k, concat_v, mask, false);
97+
sdpa->set_friendly_name("stateful_sdpa");
98+
99+
auto assign_k = std::make_shared<ov::op::v6::Assign>(concat_k, variable_k);
100+
auto assign_v = std::make_shared<ov::op::v6::Assign>(concat_v, variable_v);
101+
assign_k->set_friendly_name("pastk_write");
102+
assign_v->set_friendly_name("pastv_write");
103+
104+
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(sdpa)};
105+
ov::SinkVector sinks{assign_k, assign_v};
106+
function = std::make_shared<ov::Model>(results,
107+
sinks,
108+
ov::ParameterVector{q, k, v, mask, past_init, beam_idx},
109+
"StatefulSdpaBoolMask");
110+
}
111+
112+
/// Fills `inputs` with one tensor per model parameter, chosen by element type.
///
/// - f32 parameters get random data from utils::fill_data_random.
/// - boolean parameters get a fixed pattern: false at every flat index that is
///   a multiple of 3, true elsewhere.
/// - i32 parameters get `i % shape[0]` at flat index i (modulo 1, i.e. all
///   zeros, when the shape is empty or its first dimension is 0).
///
/// Any other element type fails the test.
///
/// @param targetInputStaticShapes static shapes, matched to the model
///        parameters by position; must contain one entry per parameter.
void StatefulSdpaBoolMaskTest::generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) {
    inputs.clear();

    const auto& parameters = function->get_parameters();
    // Shapes are consumed positionally below; a count mismatch would otherwise
    // read past the end of targetInputStaticShapes.
    ASSERT_EQ(parameters.size(), targetInputStaticShapes.size())
        << "Number of static shapes does not match the number of model parameters";

    for (size_t idx = 0; idx < parameters.size(); ++idx) {
        const auto& param = parameters[idx];
        const auto& shape = targetInputStaticShapes[idx];
        const auto& type = param->get_element_type();

        if (type == ov::element::f32) {
            ov::Tensor tensor{ov::element::f32, shape};
            utils::fill_data_random(static_cast<float*>(tensor.data()), tensor.get_size(), 2, -1, 10);
            inputs.insert({param, tensor});
        } else if (type == ov::element::boolean) {
            ov::Tensor tensor{ov::element::boolean, shape};
            auto* data = tensor.data<bool>();
            for (size_t i = 0; i < tensor.get_size(); ++i) {
                data[i] = (i % 3) != 0;
            }
            inputs.insert({param, tensor});
        } else if (type == ov::element::i32) {
            ov::Tensor tensor{ov::element::i32, shape};
            auto* data = tensor.data<int32_t>();
            // Keep the modulo in size_t so signed and unsigned operands are not
            // mixed; fall back to 1 to avoid dividing by zero.
            const size_t modulus = (!shape.empty() && shape[0] != 0) ? shape[0] : size_t{1};
            for (size_t i = 0; i < tensor.get_size(); ++i) {
                data[i] = static_cast<int32_t>(i % modulus);
            }
            inputs.insert({param, tensor});
        } else {
            FAIL() << "Unexpected parameter precision " << type;
        }
    }
}
146+
147+
// Runs the stateful SDPA model with the parameterized inference precision and
// checks the plugin-side results for the "ScaledAttn" node type.
// The test is skipped, with a reason in the log, when the host CPU lacks the
// capability queried for the requested precision.
TEST_P(StatefulSdpaBoolMaskTest, CompareWithRefs) {
    SKIP_IF_CURRENT_TEST_IS_DISABLED();

    const auto inferencePrecision = GetParam();
    if (inferencePrecision == ov::element::bf16 && !ov::with_cpu_x86_bfloat16()) {
        GTEST_SKIP() << "bf16 inference precision requested, but ov::with_cpu_x86_bfloat16() is false";
    }
    // f16 is accepted on either of the two capabilities queried below.
    if (inferencePrecision == ov::element::f16 && !ov::with_cpu_x86_avx512_core_fp16() &&
        !ov::with_cpu_neon_fp16()) {
        GTEST_SKIP() << "f16 inference precision requested, but neither ov::with_cpu_x86_avx512_core_fp16() "
                        "nor ov::with_cpu_neon_fp16() is true";
    }

    run();
    CheckPluginRelatedResults(compiledModel, "ScaledAttn");
}
163+
164+
} // namespace test
165+
} // namespace ov
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Copyright (C) 2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
#pragma once
5+
6+
#include <string>
7+
8+
#include "shared_test_classes/base/ov_subgraph.hpp"
9+
#include "utils/cpu_test_utils.hpp"
10+
11+
using namespace CPUTestUtils;
12+
13+
namespace ov {
14+
namespace test {
15+
16+
using StatefulSdpaBoolMaskParams = ov::element::Type;
17+
18+
/// Shared fixture for the stateful SDPA boolean-mask subgraph test.
///
/// Parameterized by a single element type (StatefulSdpaBoolMaskParams);
/// platform-specific sources instantiate it with the precisions they cover.
class StatefulSdpaBoolMaskTest : public testing::WithParamInterface<StatefulSdpaBoolMaskParams>,
                                 virtual public SubgraphBaseTest,
                                 public CPUTestUtils::CPUTestsBase {
public:
    /// Returns the test-case name suffix for one parameter value.
    static std::string getTestCaseName(const testing::TestParamInfo<StatefulSdpaBoolMaskParams>& obj);

protected:
    /// Sets up the plugin configuration and builds the model under test.
    void SetUp() override;

    /// Creates the input tensors for the given static shapes (one per model parameter).
    void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override;
};
28+
29+
} // namespace test
30+
} // namespace ov

src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/stateful_sdpa_bool_mask.cpp

Lines changed: 6 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -2,153 +2,19 @@
22
// SPDX-License-Identifier: Apache-2.0
33
//
44

5-
#include <gtest/gtest.h>
6-
7-
#include <cstddef>
8-
#include <cstdint>
9-
#include <memory>
10-
#include <vector>
11-
12-
#include "common_test_utils/include/common_test_utils/data_utils.hpp"
13-
#include "common_test_utils/include/common_test_utils/ov_tensor_utils.hpp"
14-
#include "internal_properties.hpp"
15-
#include "openvino/core/dimension.hpp"
16-
#include "openvino/core/model.hpp"
17-
#include "openvino/core/partial_shape.hpp"
18-
#include "openvino/core/type/element_type.hpp"
19-
#include "openvino/op/assign.hpp"
20-
#include "openvino/op/concat.hpp"
21-
#include "openvino/op/constant.hpp"
22-
#include "openvino/op/gather.hpp"
23-
#include "openvino/op/parameter.hpp"
24-
#include "openvino/op/read_value.hpp"
25-
#include "openvino/op/scaled_dot_product_attention.hpp"
26-
#include "openvino/op/util/variable.hpp"
27-
#include "openvino/opsets/opset13.hpp"
28-
#include "openvino/pass/manager.hpp"
29-
#include "shared_test_classes/base/ov_subgraph.hpp"
30-
#include "utils/cpu_test_utils.hpp"
5+
#include "custom/subgraph_tests/src/classes/stateful_sdpa_bool_mask.hpp"
316

32-
using namespace ov::test;
33-
using namespace CPUTestUtils;
7+
#include <gtest/gtest.h>
348

359
namespace ov {
3610
namespace test {
3711

3812
namespace {
3913

40-
class StatefulSdpaBoolMaskTest : public ov::test::SubgraphBaseTest, public CPUTestsBase {
41-
protected:
42-
void SetUp() override {
43-
targetDevice = ov::test::utils::DEVICE_CPU;
44-
configuration[ov::hint::inference_precision.name()] = ov::element::bf16;
45-
configuration[ov::hint::kv_cache_precision.name()] = ov::element::bf16;
46-
rel_threshold = 0.02f;
47-
abs_threshold = 0.02f;
48-
selectedType = makeSelectedTypeStr(getPrimitiveType(), ov::element::bf16);
49-
50-
const InputShape q_shape{{-1, 8, -1, 64}, {{1, 8, 10, 64}}};
51-
const InputShape k_shape{{-1, 8, -1, 64}, {{1, 8, 10, 64}}};
52-
const InputShape v_shape{{-1, 8, -1, 64}, {{1, 8, 10, 64}}};
53-
const InputShape mask_shape{{1, 1, -1, -1}, {{1, 1, 10, 10}}};
54-
const InputShape past_shape{{-1, 8, -1, 64}, {{1, 8, 0, 64}}};
55-
const InputShape beam_shape{{-1}, {{1}}};
56-
57-
init_input_shapes({q_shape, k_shape, v_shape, mask_shape, past_shape, beam_shape});
58-
59-
auto q = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, inputDynamicShapes[0]);
60-
auto k = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, inputDynamicShapes[1]);
61-
auto v = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, inputDynamicShapes[2]);
62-
auto mask = std::make_shared<ov::op::v0::Parameter>(ov::element::boolean, inputDynamicShapes[3]);
63-
auto past_init = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, inputDynamicShapes[4]);
64-
auto beam_idx = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, inputDynamicShapes[5]);
65-
66-
q->set_friendly_name("q");
67-
k->set_friendly_name("k");
68-
v->set_friendly_name("v");
69-
mask->set_friendly_name("attention_mask");
70-
past_init->set_friendly_name("past_init");
71-
beam_idx->set_friendly_name("beam_idx");
72-
73-
auto variable_k = std::make_shared<ov::op::util::Variable>(
74-
ov::op::util::VariableInfo{inputDynamicShapes[4], ov::element::bf16, "pastk"});
75-
auto variable_v = std::make_shared<ov::op::util::Variable>(
76-
ov::op::util::VariableInfo{inputDynamicShapes[4], ov::element::bf16, "pastv"});
77-
78-
auto past_k = std::make_shared<ov::op::v6::ReadValue>(past_init, variable_k);
79-
auto past_v = std::make_shared<ov::op::v6::ReadValue>(past_init, variable_v);
80-
past_k->set_friendly_name("pastk_read");
81-
past_v->set_friendly_name("pastv_read");
82-
83-
auto axis = ov::op::v0::Constant::create(ov::element::i32, {1}, {0});
84-
auto gather_k = std::make_shared<ov::op::v8::Gather>(past_k, beam_idx, axis);
85-
auto gather_v = std::make_shared<ov::op::v8::Gather>(past_v, beam_idx, axis);
86-
gather_k->set_batch_dims(0);
87-
gather_v->set_batch_dims(0);
88-
89-
auto concat_k = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{gather_k, k}, 2);
90-
auto concat_v = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{gather_v, v}, 2);
91-
92-
auto sdpa = std::make_shared<ov::opset13::ScaledDotProductAttention>(q, concat_k, concat_v, mask, false);
93-
sdpa->set_friendly_name("stateful_sdpa");
94-
95-
auto assign_k = std::make_shared<ov::op::v6::Assign>(concat_k, variable_k);
96-
auto assign_v = std::make_shared<ov::op::v6::Assign>(concat_v, variable_v);
97-
assign_k->set_friendly_name("pastk_write");
98-
assign_v->set_friendly_name("pastv_write");
99-
100-
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(sdpa)};
101-
ov::SinkVector sinks{assign_k, assign_v};
102-
function = std::make_shared<ov::Model>(results,
103-
sinks,
104-
ov::ParameterVector{q, k, v, mask, past_init, beam_idx},
105-
"StatefulSdpaBoolMask");
106-
}
107-
108-
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
109-
inputs.clear();
110-
111-
const auto& parameters = function->get_parameters();
112-
for (size_t idx = 0; idx < parameters.size(); ++idx) {
113-
const auto& param = parameters[idx];
114-
const auto& shape = targetInputStaticShapes[idx];
115-
if (param->get_element_type() == ov::element::bf16) {
116-
ov::Tensor tensor{ov::element::bf16, shape};
117-
utils::fill_data_random(static_cast<ov::bfloat16*>(tensor.data()), tensor.get_size(), 2, -1, 10);
118-
inputs.insert({param, tensor});
119-
} else if (param->get_element_type() == ov::element::boolean) {
120-
ov::Tensor tensor{ov::element::boolean, shape};
121-
auto* data = tensor.data<bool>();
122-
for (size_t i = 0; i < tensor.get_size(); ++i) {
123-
data[i] = (i % 3) != 0;
124-
}
125-
inputs.insert({param, tensor});
126-
} else if (param->get_element_type() == ov::element::i32) {
127-
ov::Tensor tensor{ov::element::i32, shape};
128-
auto* data = tensor.data<int32_t>();
129-
int32_t denom = 1;
130-
if (!shape.empty() && shape[0] != 0) {
131-
denom = static_cast<int32_t>(shape[0]);
132-
}
133-
for (size_t i = 0; i < tensor.get_size(); ++i) {
134-
data[i] = static_cast<int32_t>(i % denom);
135-
}
136-
inputs.insert({param, tensor});
137-
} else {
138-
FAIL() << "Unexpected parameter precision " << param->get_element_type();
139-
}
140-
}
141-
}
142-
};
143-
144-
TEST_F(StatefulSdpaBoolMaskTest, CompareWithRefs) {
145-
SKIP_IF_CURRENT_TEST_IS_DISABLED();
146-
if (!ov::with_cpu_x86_bfloat16()) {
147-
GTEST_SKIP();
148-
}
149-
run();
150-
CheckPluginRelatedResults(compiledModel, "ScaledAttn");
151-
}
14+
INSTANTIATE_TEST_SUITE_P(smoke_StatefulSdpaBoolMask,
15+
StatefulSdpaBoolMaskTest,
16+
::testing::Values(ov::element::bf16, ov::element::f16),
17+
StatefulSdpaBoolMaskTest::getTestCaseName);
15218

15319
} // namespace
15420

0 commit comments

Comments
 (0)