|
2 | 2 | // SPDX-License-Identifier: Apache-2.0 |
3 | 3 | // |
4 | 4 |
|
5 | | -#include <gtest/gtest.h> |
6 | | - |
7 | | -#include <cstddef> |
8 | | -#include <cstdint> |
9 | | -#include <memory> |
10 | | -#include <vector> |
11 | | - |
12 | | -#include "common_test_utils/include/common_test_utils/data_utils.hpp" |
13 | | -#include "common_test_utils/include/common_test_utils/ov_tensor_utils.hpp" |
14 | | -#include "internal_properties.hpp" |
15 | | -#include "openvino/core/dimension.hpp" |
16 | | -#include "openvino/core/model.hpp" |
17 | | -#include "openvino/core/partial_shape.hpp" |
18 | | -#include "openvino/core/type/element_type.hpp" |
19 | | -#include "openvino/op/assign.hpp" |
20 | | -#include "openvino/op/concat.hpp" |
21 | | -#include "openvino/op/constant.hpp" |
22 | | -#include "openvino/op/gather.hpp" |
23 | | -#include "openvino/op/parameter.hpp" |
24 | | -#include "openvino/op/read_value.hpp" |
25 | | -#include "openvino/op/scaled_dot_product_attention.hpp" |
26 | | -#include "openvino/op/util/variable.hpp" |
27 | | -#include "openvino/opsets/opset13.hpp" |
28 | | -#include "openvino/pass/manager.hpp" |
29 | | -#include "shared_test_classes/base/ov_subgraph.hpp" |
30 | | -#include "utils/cpu_test_utils.hpp" |
| 5 | +#include "custom/subgraph_tests/src/classes/stateful_sdpa_bool_mask.hpp" |
31 | 6 |
|
32 | | -using namespace ov::test; |
33 | | -using namespace CPUTestUtils; |
| 7 | +#include <gtest/gtest.h> |
34 | 8 |
|
35 | 9 | namespace ov { |
36 | 10 | namespace test { |
37 | 11 |
|
38 | 12 | namespace { |
39 | 13 |
|
40 | | -class StatefulSdpaBoolMaskTest : public ov::test::SubgraphBaseTest, public CPUTestsBase { |
41 | | -protected: |
42 | | - void SetUp() override { |
43 | | - targetDevice = ov::test::utils::DEVICE_CPU; |
44 | | - configuration[ov::hint::inference_precision.name()] = ov::element::bf16; |
45 | | - configuration[ov::hint::kv_cache_precision.name()] = ov::element::bf16; |
46 | | - rel_threshold = 0.02f; |
47 | | - abs_threshold = 0.02f; |
48 | | - selectedType = makeSelectedTypeStr(getPrimitiveType(), ov::element::bf16); |
49 | | - |
50 | | - const InputShape q_shape{{-1, 8, -1, 64}, {{1, 8, 10, 64}}}; |
51 | | - const InputShape k_shape{{-1, 8, -1, 64}, {{1, 8, 10, 64}}}; |
52 | | - const InputShape v_shape{{-1, 8, -1, 64}, {{1, 8, 10, 64}}}; |
53 | | - const InputShape mask_shape{{1, 1, -1, -1}, {{1, 1, 10, 10}}}; |
54 | | - const InputShape past_shape{{-1, 8, -1, 64}, {{1, 8, 0, 64}}}; |
55 | | - const InputShape beam_shape{{-1}, {{1}}}; |
56 | | - |
57 | | - init_input_shapes({q_shape, k_shape, v_shape, mask_shape, past_shape, beam_shape}); |
58 | | - |
59 | | - auto q = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, inputDynamicShapes[0]); |
60 | | - auto k = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, inputDynamicShapes[1]); |
61 | | - auto v = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, inputDynamicShapes[2]); |
62 | | - auto mask = std::make_shared<ov::op::v0::Parameter>(ov::element::boolean, inputDynamicShapes[3]); |
63 | | - auto past_init = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, inputDynamicShapes[4]); |
64 | | - auto beam_idx = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, inputDynamicShapes[5]); |
65 | | - |
66 | | - q->set_friendly_name("q"); |
67 | | - k->set_friendly_name("k"); |
68 | | - v->set_friendly_name("v"); |
69 | | - mask->set_friendly_name("attention_mask"); |
70 | | - past_init->set_friendly_name("past_init"); |
71 | | - beam_idx->set_friendly_name("beam_idx"); |
72 | | - |
73 | | - auto variable_k = std::make_shared<ov::op::util::Variable>( |
74 | | - ov::op::util::VariableInfo{inputDynamicShapes[4], ov::element::bf16, "pastk"}); |
75 | | - auto variable_v = std::make_shared<ov::op::util::Variable>( |
76 | | - ov::op::util::VariableInfo{inputDynamicShapes[4], ov::element::bf16, "pastv"}); |
77 | | - |
78 | | - auto past_k = std::make_shared<ov::op::v6::ReadValue>(past_init, variable_k); |
79 | | - auto past_v = std::make_shared<ov::op::v6::ReadValue>(past_init, variable_v); |
80 | | - past_k->set_friendly_name("pastk_read"); |
81 | | - past_v->set_friendly_name("pastv_read"); |
82 | | - |
83 | | - auto axis = ov::op::v0::Constant::create(ov::element::i32, {1}, {0}); |
84 | | - auto gather_k = std::make_shared<ov::op::v8::Gather>(past_k, beam_idx, axis); |
85 | | - auto gather_v = std::make_shared<ov::op::v8::Gather>(past_v, beam_idx, axis); |
86 | | - gather_k->set_batch_dims(0); |
87 | | - gather_v->set_batch_dims(0); |
88 | | - |
89 | | - auto concat_k = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{gather_k, k}, 2); |
90 | | - auto concat_v = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{gather_v, v}, 2); |
91 | | - |
92 | | - auto sdpa = std::make_shared<ov::opset13::ScaledDotProductAttention>(q, concat_k, concat_v, mask, false); |
93 | | - sdpa->set_friendly_name("stateful_sdpa"); |
94 | | - |
95 | | - auto assign_k = std::make_shared<ov::op::v6::Assign>(concat_k, variable_k); |
96 | | - auto assign_v = std::make_shared<ov::op::v6::Assign>(concat_v, variable_v); |
97 | | - assign_k->set_friendly_name("pastk_write"); |
98 | | - assign_v->set_friendly_name("pastv_write"); |
99 | | - |
100 | | - ov::ResultVector results{std::make_shared<ov::op::v0::Result>(sdpa)}; |
101 | | - ov::SinkVector sinks{assign_k, assign_v}; |
102 | | - function = std::make_shared<ov::Model>(results, |
103 | | - sinks, |
104 | | - ov::ParameterVector{q, k, v, mask, past_init, beam_idx}, |
105 | | - "StatefulSdpaBoolMask"); |
106 | | - } |
107 | | - |
108 | | - void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override { |
109 | | - inputs.clear(); |
110 | | - |
111 | | - const auto& parameters = function->get_parameters(); |
112 | | - for (size_t idx = 0; idx < parameters.size(); ++idx) { |
113 | | - const auto& param = parameters[idx]; |
114 | | - const auto& shape = targetInputStaticShapes[idx]; |
115 | | - if (param->get_element_type() == ov::element::bf16) { |
116 | | - ov::Tensor tensor{ov::element::bf16, shape}; |
117 | | - utils::fill_data_random(static_cast<ov::bfloat16*>(tensor.data()), tensor.get_size(), 2, -1, 10); |
118 | | - inputs.insert({param, tensor}); |
119 | | - } else if (param->get_element_type() == ov::element::boolean) { |
120 | | - ov::Tensor tensor{ov::element::boolean, shape}; |
121 | | - auto* data = tensor.data<bool>(); |
122 | | - for (size_t i = 0; i < tensor.get_size(); ++i) { |
123 | | - data[i] = (i % 3) != 0; |
124 | | - } |
125 | | - inputs.insert({param, tensor}); |
126 | | - } else if (param->get_element_type() == ov::element::i32) { |
127 | | - ov::Tensor tensor{ov::element::i32, shape}; |
128 | | - auto* data = tensor.data<int32_t>(); |
129 | | - int32_t denom = 1; |
130 | | - if (!shape.empty() && shape[0] != 0) { |
131 | | - denom = static_cast<int32_t>(shape[0]); |
132 | | - } |
133 | | - for (size_t i = 0; i < tensor.get_size(); ++i) { |
134 | | - data[i] = static_cast<int32_t>(i % denom); |
135 | | - } |
136 | | - inputs.insert({param, tensor}); |
137 | | - } else { |
138 | | - FAIL() << "Unexpected parameter precision " << param->get_element_type(); |
139 | | - } |
140 | | - } |
141 | | - } |
142 | | -}; |
143 | | - |
144 | | -TEST_F(StatefulSdpaBoolMaskTest, CompareWithRefs) { |
145 | | - SKIP_IF_CURRENT_TEST_IS_DISABLED(); |
146 | | - if (!ov::with_cpu_x86_bfloat16()) { |
147 | | - GTEST_SKIP(); |
148 | | - } |
149 | | - run(); |
150 | | - CheckPluginRelatedResults(compiledModel, "ScaledAttn"); |
151 | | -} |
| 14 | +INSTANTIATE_TEST_SUITE_P(smoke_StatefulSdpaBoolMask, |
| 15 | + StatefulSdpaBoolMaskTest, |
| 16 | + ::testing::Values(ov::element::bf16, ov::element::f16), |
| 17 | + StatefulSdpaBoolMaskTest::getTestCaseName); |
152 | 18 |
|
153 | 19 | } // namespace |
154 | 20 |
|
|
0 commit comments