#include <gtest/gtest.h>

#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

#include "common_test_utils/data_utils.hpp"
#include "common_test_utils/ov_tensor_utils.hpp"
#include "internal_properties.hpp"
#include "openvino/core/dimension.hpp"
#include "openvino/core/model.hpp"
#include "openvino/core/partial_shape.hpp"
#include "openvino/core/type/element_type.hpp"
#include "openvino/op/assign.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/read_value.hpp"
#include "openvino/op/result.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/op/util/variable.hpp"
#include "openvino/opsets/opset13.hpp"
#include "openvino/runtime/system_conf.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "utils/cpu_test_utils.hpp"

using namespace ov::test;
using namespace CPUTestUtils;

namespace ov {
namespace test {

namespace {

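// Builds a stateful scaled-dot-product-attention subgraph with a boolean attention
// mask: the KV cache lives in ReadValue/Assign state variables, is reordered by
// beam_idx via Gather, and the freshly computed K/V are concatenated to the cached
// values along the sequence axis before being fed to SDPA and written back.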
class StatefulSdpaBoolMaskTest : public ov::test::SubgraphBaseTest, public CPUTestsBase {
protected:
    void SetUp() override {
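        // Run on CPU with bf16 inference and bf16 KV-cache precision; the relaxed
        // tolerances account for bf16 rounding against the reference results.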
        targetDevice = ov::test::utils::DEVICE_CPU;
        configuration[ov::hint::inference_precision.name()] = ov::element::bf16;
        configuration[ov::hint::kv_cache_precision.name()] = ov::element::bf16;
        rel_threshold = 0.02f;
        abs_threshold = 0.02f;
        selectedType = makeSelectedTypeStr(getPrimitiveType(), ov::element::bf16);

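        // All inputs use a dynamic batch and sequence length with one concrete shape
        // each; the past KV cache starts empty (sequence length 0) and the boolean
        // mask spans the full query-by-key length.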
        const InputShape q_shape{{-1, 8, -1, 64}, {{1, 8, 10, 64}}};
        const InputShape k_shape{{-1, 8, -1, 64}, {{1, 8, 10, 64}}};
        const InputShape v_shape{{-1, 8, -1, 64}, {{1, 8, 10, 64}}};
        const InputShape mask_shape{{1, 1, -1, -1}, {{1, 1, 10, 10}}};
        const InputShape past_shape{{-1, 8, -1, 64}, {{1, 8, 0, 64}}};
        const InputShape beam_shape{{-1}, {{1}}};

        init_input_shapes({q_shape, k_shape, v_shape, mask_shape, past_shape, beam_shape});

        auto q = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, inputDynamicShapes[0]);
        auto k = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, inputDynamicShapes[1]);
        auto v = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, inputDynamicShapes[2]);
        auto mask = std::make_shared<ov::op::v0::Parameter>(ov::element::boolean, inputDynamicShapes[3]);
        auto past_init = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, inputDynamicShapes[4]);
        auto beam_idx = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, inputDynamicShapes[5]);

        q->set_friendly_name("q");
        k->set_friendly_name("k");
        v->set_friendly_name("v");
        mask->set_friendly_name("attention_mask");
        past_init->set_friendly_name("past_init");
        beam_idx->set_friendly_name("beam_idx");

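        // Stateful KV-cache pattern: each ReadValue/Assign pair shares a Variable,
        // Gather over the batch axis reorders the cache by beam_idx, and Concat on
        // axis 2 appends the current K/V to the cached values.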
        auto variable_k = std::make_shared<ov::op::util::Variable>(
            ov::op::util::VariableInfo{inputDynamicShapes[4], ov::element::bf16, "pastk"});
        auto variable_v = std::make_shared<ov::op::util::Variable>(
            ov::op::util::VariableInfo{inputDynamicShapes[4], ov::element::bf16, "pastv"});

        auto past_k = std::make_shared<ov::op::v6::ReadValue>(past_init, variable_k);
        auto past_v = std::make_shared<ov::op::v6::ReadValue>(past_init, variable_v);
        past_k->set_friendly_name("pastk_read");
        past_v->set_friendly_name("pastv_read");

        auto axis = ov::op::v0::Constant::create(ov::element::i32, {1}, {0});
        auto gather_k = std::make_shared<ov::op::v8::Gather>(past_k, beam_idx, axis, 0);
        auto gather_v = std::make_shared<ov::op::v8::Gather>(past_v, beam_idx, axis, 0);

        auto concat_k = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{gather_k, k}, 2);
        auto concat_v = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{gather_v, v}, 2);

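        // SDPA with an explicit boolean mask: true positions take part in attention
        // and false positions are masked out; the causal flag stays disabled because
        // masking is fully described by the mask input.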
        auto sdpa = std::make_shared<ov::opset13::ScaledDotProductAttention>(q, concat_k, concat_v, mask, false);
        sdpa->set_friendly_name("stateful_sdpa");

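        // Write the concatenated K/V back into the state variables so the cache
        // persists across inference requests.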
        auto assign_k = std::make_shared<ov::op::v6::Assign>(concat_k, variable_k);
        auto assign_v = std::make_shared<ov::op::v6::Assign>(concat_v, variable_v);
        assign_k->set_friendly_name("pastk_write");
        assign_v->set_friendly_name("pastv_write");

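        // The Assign nodes are registered as sinks so the model keeps the state writes alive.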
        ov::ResultVector results{std::make_shared<ov::op::v0::Result>(sdpa)};
        ov::SinkVector sinks{assign_k, assign_v};
        function = std::make_shared<ov::Model>(results,
                                               sinks,
                                               ov::ParameterVector{q, k, v, mask, past_init, beam_idx},
                                               "StatefulSdpaBoolMask");
    }

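    // Custom input generation: small random bf16 values for Q/K/V and the cache
    // initializer, a boolean mask that masks out every third position, and beam
    // indices that enumerate the batch (identity reordering for batch size 1).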
    void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
        inputs.clear();

        const auto& parameters = function->get_parameters();
        for (size_t idx = 0; idx < parameters.size(); ++idx) {
            const auto& param = parameters[idx];
            const auto& shape = targetInputStaticShapes[idx];
            if (param->get_element_type() == ov::element::bf16) {
                ov::Tensor tensor{ov::element::bf16, shape};
                utils::fill_data_random(static_cast<ov::bfloat16*>(tensor.data()), tensor.get_size(), 2, -1, 10);
                inputs.insert({param, tensor});
            } else if (param->get_element_type() == ov::element::boolean) {
                ov::Tensor tensor{ov::element::boolean, shape};
                auto* data = tensor.data<bool>();
                for (size_t i = 0; i < tensor.get_size(); ++i) {
                    data[i] = (i % 3) != 0;
                }
                inputs.insert({param, tensor});
            } else if (param->get_element_type() == ov::element::i32) {
                ov::Tensor tensor{ov::element::i32, shape};
                auto* data = tensor.data<int32_t>();
                int32_t denom = 1;
                if (!shape.empty() && shape[0] != 0) {
                    denom = static_cast<int32_t>(shape[0]);
                }
                for (size_t i = 0; i < tensor.get_size(); ++i) {
                    data[i] = static_cast<int32_t>(i % denom);
                }
                inputs.insert({param, tensor});
            } else {
                FAIL() << "Unexpected parameter precision " << param->get_element_type();
            }
        }
    }
};

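// Runs inference and compares with the reference implementation; skipped on hosts
// without native bf16 support. CheckPluginRelatedResults then verifies that the
// compiled graph contains a "ScaledAttn" node, i.e. the stateful SDPA pattern was
// picked up by the CPU plugin.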
TEST_F(StatefulSdpaBoolMaskTest, CompareWithRefs) {
    SKIP_IF_CURRENT_TEST_IS_DISABLED();
    if (!ov::with_cpu_x86_bfloat16()) {
        GTEST_SKIP() << "CPU does not natively support bf16";
    }
    run();
    CheckPluginRelatedResults(compiledModel, "ScaledAttn");
}

} // namespace

} // namespace test
} // namespace ov