@@ -93,34 +93,38 @@ void recurrent_linear_attn(const ov::intel_cpu::PlainTensor& query,
                            const ov::intel_cpu::CpuParallelPtr& cpu_parallel) {
     size_t B = query.m_dims[0];
     size_t T = query.m_dims[1];
-    size_t H = query.m_dims[2];
+    size_t qk_heads = query.m_dims[2];
     size_t K = query.m_dims[3];
+    size_t v_heads = value.m_dims[2];
     size_t V = value.m_dims[3];
     const size_t K_HEAD_DIMS = K;
     const size_t V_HEAD_DIMS = V;
     const float q_scale = 1 / std::sqrt(static_cast<float>(K_HEAD_DIMS));
-    cpu_parallel->parallel_for3d(B, H, V, [&](size_t i_b, size_t i_h, size_t i_v) {
+    const size_t group_size = v_heads / qk_heads;
+    cpu_parallel->parallel_for3d(B, v_heads, V, [&](size_t i_b, size_t i_h, size_t i_v) {
         size_t tid = parallel_get_thread_num();
         float* init_state = temp_buffer + tid * 3 * K_HEAD_DIMS;
         float* b_k = temp_buffer + tid * 3 * K_HEAD_DIMS + K_HEAD_DIMS;
         float* b_q = temp_buffer + tid * 3 * K_HEAD_DIMS + 2 * K_HEAD_DIMS;
-        // B, T, H, K
-        float* q_ptr = query.ptr<float>(i_b, 0, i_h);
-        float* k_ptr = key.ptr<float>(i_b, 0, i_h);
+        const size_t hk = i_h / group_size;
+        // B, T, qk_heads, K
+        float* q_ptr = query.ptr<float>(i_b, 0, hk);
+        float* k_ptr = key.ptr<float>(i_b, 0, hk);
+        // B, T, v_heads, V
         float* v_ptr = value.ptr<float>(i_b, 0, i_h);
-        // B, H, K, V
+        // B, v_heads, K, V
         for (size_t j = 0; j < K_HEAD_DIMS; j++) {
             init_state[j] = recurrent_state.at<float>({i_b, i_h, j, i_v});
         }

        for (size_t i = 0; i < T; i++) {
-            // gate: B, T, H
+            // gate: B, T, v_heads
             float b_g = gate.at<float>({i_b, i, i_h});
             float b_beta = beta.at<float>({i_b, i, i_h});
             b_g = exp(b_g);
             for (size_t j = 0; j < K_HEAD_DIMS; j++) {
-                b_k[j] = k_ptr[i * H * K_HEAD_DIMS + j];
-                b_q[j] = q_ptr[i * H * K_HEAD_DIMS + j];
+                b_k[j] = k_ptr[i * qk_heads * K_HEAD_DIMS + j];
+                b_q[j] = q_ptr[i * qk_heads * K_HEAD_DIMS + j];
             }
             if (use_qk_l2norm) {
                 l2norm(b_k, K_HEAD_DIMS, k_l2_norm_eps);
@@ -130,8 +134,8 @@ void recurrent_linear_attn(const ov::intel_cpu::PlainTensor& query,
             // h0 * gate
             multiply_scalar(init_state, init_state, b_g, K_HEAD_DIMS);
             float h_k = dot_product(init_state, b_k, K_HEAD_DIMS, nullptr, nullptr, nullptr, 0);
-            // B, T, H, V
-            float b_v = v_ptr[i_v + i * H * V_HEAD_DIMS];
+            // B, T, v_heads, V
+            float b_v = v_ptr[i_v + i * v_heads * V_HEAD_DIMS];
             b_v -= h_k;
             // b_v * b_k
             b_v *= b_beta;
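For orientation, a minimal standalone sketch (example values, not plugin code) of the grouped-query head mapping the kernel now parallelizes over: the loop runs once per value head, and consecutive groups of group_size value heads share a single query/key head via hk = i_h / group_size.

#include <cstddef>
#include <cstdio>

int main() {
    const std::size_t qk_heads = 2;  // shared query/key heads (example values)
    const std::size_t v_heads = 4;   // value heads, assumed a multiple of qk_heads
    const std::size_t group_size = v_heads / qk_heads;  // here: 2
    for (std::size_t i_h = 0; i_h < v_heads; ++i_h) {
        const std::size_t hk = i_h / group_size;  // q/k head serving value head i_h
        std::printf("value head %zu -> q/k head %zu\n", i_h, hk);
        // prints: 0 -> 0, 1 -> 0, 2 -> 1, 3 -> 1
    }
}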
@@ -11,6 +11,9 @@ std::vector<gated_delta_net_params> test_cases = {
     {1, 39, 2, 2, 16, 32, ov::element::f32, "CPU"},
     {2, 16, 2, 2, 16, 16, ov::element::f32, "CPU"},
     {2, 39, 2, 2, 16, 16, ov::element::f32, "CPU"},
+    // grouped-query cases: qk_heads != v_heads
+    {1, 16, 2, 4, 16, 16, ov::element::f32, "CPU"},
+    {2, 8, 4, 8, 16, 16, ov::element::f32, "CPU"},
     {1, 16, 2, 2, 128, 128, ov::element::f32, "CPU"},
     {1, 16, 2, 2, 64, 128, ov::element::f32, "CPU"},
     {1, 31, 2, 2, 128, 128, ov::element::f32, "CPU"},
@@ -14,7 +14,6 @@ TEST_P(GatedDeltaNet, CompareWithRefs) {
     auto function = compiledModel.get_runtime_model();
     CheckNumberOfNodesWithType(function, {"GatedDeltaNet"}, 1);
     CheckNumberOfNodesWithType(function, {"Transpose"}, 0);
-    CheckNumberOfNodesWithType(function, {"Concat"}, 0);
     CheckNumberOfNodesWithType(function, {"ReduceSum"}, 0);
     CheckNumberOfNodesWithType(function, {"Multiply"}, 0);
     CheckNumberOfNodesWithType(function, {"Divide"}, 0);

Contributor (on the removed Concat check): How does this relate to the fix in the CPU plugin?
117 changes: 108 additions & 9 deletions src/tests/functional/plugin/shared/src/subgraph/gated_delta_net.cpp
@@ -6,6 +6,8 @@

 #include <climits>
 #include <cmath>
+#include <utility>
+#include <vector>

 #include "common_test_utils/ov_tensor_utils.hpp"
 #include "openvino/core/type/bfloat16.hpp"
@@ -19,10 +21,15 @@
 #include "openvino/op/exp.hpp"
 #include "openvino/op/gated_delta_net.hpp"
 #include "openvino/op/gather.hpp"
+#include "openvino/op/gather_nd.hpp"
 #include "openvino/op/less.hpp"
 #include "openvino/op/loop.hpp"
 #include "openvino/op/multiply.hpp"
+#include "openvino/op/non_zero.hpp"
 #include "openvino/op/parameter.hpp"
+#include "openvino/op/power.hpp"
+#include "openvino/op/range.hpp"
+#include "openvino/op/reduce_max.hpp"
 #include "openvino/op/reduce_prod.hpp"
 #include "openvino/op/reduce_sum.hpp"
 #include "openvino/op/reshape.hpp"
@@ -34,6 +41,7 @@
 #include "openvino/op/subtract.hpp"
 #include "openvino/op/transpose.hpp"
 #include "openvino/op/unsqueeze.hpp"
+#include "openvino/op/variadic_split.hpp"
 #include "openvino/runtime/properties.hpp"

 namespace ov {
@@ -48,8 +56,8 @@ std::shared_ptr<ov::Model> GatedDeltaNet::buildLoopedGDN(int32_t batch,
                                                          ov::element::Type dtype) {
     const ov::PartialShape qk_shape{batch, seq_len, qk_head_num, qk_head_size};
     const ov::PartialShape v_tensor_shape{batch, seq_len, v_head_num, v_head_size};
-    const ov::PartialShape gv_shape{batch, seq_len, qk_head_num};
-    const ov::PartialShape h_shape{batch, qk_head_num, qk_head_size, v_head_size};
+    const ov::PartialShape gv_shape{batch, seq_len, v_head_num};
+    const ov::PartialShape h_shape{batch, v_head_num, qk_head_size, v_head_size};

     auto q = std::make_shared<ov::op::v0::Parameter>(dtype, qk_shape);
     auto k = std::make_shared<ov::op::v0::Parameter>(dtype, qk_shape);
@@ -65,6 +73,39 @@ std::shared_ptr<ov::Model> GatedDeltaNet::buildLoopedGDN(int32_t batch,
     g->set_friendly_name("g");
     beta->set_friendly_name("beta");

+    const bool need_head_repeat = (qk_head_num != v_head_num);
+
+    auto repeat_qk_heads =
+        [&](const ov::Output<ov::Node>& query,
+            const ov::Output<ov::Node>& key) -> std::pair<ov::Output<ov::Node>, ov::Output<ov::Node>> {
+        if (!need_head_repeat) {
+            return {query, key};
+        }
+
+        const int64_t group_size = static_cast<int64_t>(v_head_num / qk_head_num);
+
+        std::vector<int64_t> repeated_head_ids;
+        repeated_head_ids.reserve(static_cast<size_t>(v_head_num));
+        for (int64_t h = 0; h < static_cast<int64_t>(v_head_num); ++h) {
+            repeated_head_ids.push_back(h / group_size);
+        }
+
+        auto gather_indices =
+            ov::op::v0::Constant::create(ov::element::i64, {static_cast<size_t>(v_head_num)}, repeated_head_ids);
+        auto gather_axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {2});
+
+        auto repeated_q = std::make_shared<ov::op::v8::Gather>(query, gather_indices, gather_axis, 0);
+        auto repeated_k = std::make_shared<ov::op::v8::Gather>(key, gather_indices, gather_axis, 0);
+
+        auto repeated_shape = ov::op::v0::Constant::create(
+            ov::element::i64,
+            ov::Shape{4},
+            std::vector<int64_t>{0, 0, static_cast<int64_t>(v_head_num), static_cast<int64_t>(qk_head_size)});
+        auto repeated_q_reshape = std::make_shared<ov::op::v1::Reshape>(repeated_q, repeated_shape, true);
+        auto repeated_k_reshape = std::make_shared<ov::op::v1::Reshape>(repeated_k, repeated_shape, true);
+        return {repeated_q_reshape, repeated_k_reshape};
+    };
+
     auto l2norm = [&](const ov::Output<ov::Node>& x) {
         auto sq = std::make_shared<ov::op::v1::Multiply>(x, x);
         auto axis = ov::op::v0::Constant::create(ov::element::i32, {1}, {-1});
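As an aside, a minimal sketch (example values, not test code) of the index pattern repeat_qk_heads builds above. It is the inverse view of the kernel's hk = i_h / group_size mapping: each query/key head is repeated group_size times along the head axis, so gathering a [B, T, qk_head_num, qk_head_size] tensor on axis 2 with these indices yields [B, T, v_head_num, qk_head_size].

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const std::int64_t qk_head_num = 2;  // example values
    const std::int64_t v_head_num = 4;
    const std::int64_t group_size = v_head_num / qk_head_num;
    std::vector<std::int64_t> repeated_head_ids;
    for (std::int64_t h = 0; h < v_head_num; ++h) {
        repeated_head_ids.push_back(h / group_size);  // builds {0, 0, 1, 1}
    }
    for (std::int64_t id : repeated_head_ids) {
        std::printf("%lld ", static_cast<long long>(id));  // prints: 0 0 1 1
    }
    std::printf("\n");
}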
@@ -76,18 +117,76 @@
         return std::make_shared<ov::op::v1::Multiply>(x, inv);
     };

-    auto q_norm = l2norm(q);
-    auto k_norm = l2norm(k);
+    ov::Output<ov::Node> q_for_attn = q;
+    ov::Output<ov::Node> k_for_attn = k;
+    ov::Output<ov::Node> v_for_attn = v;
+
+    if (need_head_repeat) {
+        auto flatten_q_shape = ov::op::v0::Constant::create(ov::element::i64,
+                                                            ov::Shape{4},
+                                                            std::vector<int64_t>{static_cast<int64_t>(0),
+                                                                                 static_cast<int64_t>(0),
+                                                                                 static_cast<int64_t>(1),
+                                                                                 static_cast<int64_t>(-1)});
+        auto flatten_v_shape = ov::op::v0::Constant::create(ov::element::i64,
+                                                            ov::Shape{4},
+                                                            std::vector<int64_t>{static_cast<int64_t>(0),
+                                                                                 static_cast<int64_t>(0),
+                                                                                 static_cast<int64_t>(1),
+                                                                                 static_cast<int64_t>(-1)});
+
+        auto q_flat = std::make_shared<ov::op::v1::Reshape>(q, flatten_q_shape, true);
+        auto k_flat = std::make_shared<ov::op::v1::Reshape>(k, flatten_q_shape, true);
+        auto v_flat = std::make_shared<ov::op::v1::Reshape>(v, flatten_v_shape, true);
+
+        auto qkv_concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{q_flat, k_flat, v_flat}, -1);
+        auto split_axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {-1});
+        auto split_lengths =
+            ov::op::v0::Constant::create(ov::element::i64,
+                                         {3},
+                                         {static_cast<int64_t>(qk_head_num) * static_cast<int64_t>(qk_head_size),
+                                          static_cast<int64_t>(qk_head_num) * static_cast<int64_t>(qk_head_size),
+                                          static_cast<int64_t>(v_head_num) * static_cast<int64_t>(v_head_size)});
+        auto qkv_split = std::make_shared<ov::op::v1::VariadicSplit>(qkv_concat, split_axis, split_lengths);
+
+        auto q_shape = ov::op::v0::Constant::create(ov::element::i64,
+                                                    ov::Shape{4},
+                                                    std::vector<int64_t>{static_cast<int64_t>(0),
+                                                                         static_cast<int64_t>(0),
+                                                                         static_cast<int64_t>(qk_head_num),
+                                                                         static_cast<int64_t>(qk_head_size)});
+        auto k_shape = ov::op::v0::Constant::create(ov::element::i64,
+                                                    ov::Shape{4},
+                                                    std::vector<int64_t>{static_cast<int64_t>(0),
+                                                                         static_cast<int64_t>(0),
+                                                                         static_cast<int64_t>(qk_head_num),
+                                                                         static_cast<int64_t>(qk_head_size)});
+        auto v_shape_split = ov::op::v0::Constant::create(ov::element::i64,
+                                                          ov::Shape{4},
+                                                          std::vector<int64_t>{static_cast<int64_t>(0),
+                                                                               static_cast<int64_t>(0),
+                                                                               static_cast<int64_t>(v_head_num),
+                                                                               static_cast<int64_t>(v_head_size)});
+
+        q_for_attn = std::make_shared<ov::op::v1::Reshape>(qkv_split->output(0), q_shape, true);
+        k_for_attn = std::make_shared<ov::op::v1::Reshape>(qkv_split->output(1), k_shape, true);
+        v_for_attn = std::make_shared<ov::op::v1::Reshape>(qkv_split->output(2), v_shape_split, true);
+
+        std::tie(q_for_attn, k_for_attn) = repeat_qk_heads(q_for_attn, k_for_attn);
+    }
+
+    auto q_norm = l2norm(q_for_attn);
+    auto k_norm = l2norm(k_for_attn);

-    auto v_shape = std::make_shared<ov::op::v3::ShapeOf>(v);
+    auto v_shape = std::make_shared<ov::op::v3::ShapeOf>(v_for_attn);
     auto core_attn_init =
         std::make_shared<ov::op::v3::Broadcast>(ov::op::v0::Constant::create(dtype, {}, {0.0f}), v_shape);

     auto perm_bhsd = ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3});
     auto perm_bhs = ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 2, 1});
     auto q_norm_t = std::make_shared<ov::op::v1::Transpose>(q_norm, perm_bhsd);

-    auto shape_of_q = std::make_shared<ov::op::v3::ShapeOf>(q);
+    auto shape_of_q = std::make_shared<ov::op::v3::ShapeOf>(q_for_attn);
     auto gather_q_perm_index = ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3});
     auto gather_0_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {0});
     auto gather_q_shape = std::make_shared<ov::op::v8::Gather>(shape_of_q, gather_q_perm_index, gather_0_axis, 0);
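The Concat/VariadicSplit pair in the added block is a value-preserving roundtrip: q, k, and v are each flattened to [B, T, 1, heads * head_size], concatenated on the last axis, then split back with lengths equal to their original widths and reshaped to their original layouts. A minimal sketch (plain vectors, not test code) of why such a roundtrip is lossless:

#include <cstdio>
#include <vector>

int main() {
    // Stand-ins for flattened q, k, v rows (widths 2, 2, 4).
    std::vector<float> q{1, 2}, k{3, 4}, v{5, 6, 7, 8};
    // Concat on the last axis.
    std::vector<float> qkv;
    qkv.insert(qkv.end(), q.begin(), q.end());
    qkv.insert(qkv.end(), k.begin(), k.end());
    qkv.insert(qkv.end(), v.begin(), v.end());
    // VariadicSplit with lengths {2, 2, 4} recovers the originals exactly.
    std::vector<float> q2(qkv.begin(), qkv.begin() + 2);
    std::vector<float> k2(qkv.begin() + 2, qkv.begin() + 4);
    std::vector<float> v2(qkv.begin() + 4, qkv.end());
    std::printf("%s\n", (q2 == q && k2 == k && v2 == v) ? "roundtrip ok" : "mismatch");
}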
@@ -100,7 +199,7 @@

     auto q_scaled_t = std::make_shared<ov::op::v1::Divide>(q_norm_t, q_scale);
     auto k_norm_t = std::make_shared<ov::op::v1::Transpose>(k_norm, perm_bhsd);
-    auto v_t = std::make_shared<ov::op::v1::Transpose>(v, perm_bhsd);
+    auto v_t = std::make_shared<ov::op::v1::Transpose>(v_for_attn, perm_bhsd);
     auto g_t = std::make_shared<ov::op::v1::Transpose>(g, perm_bhs);
     auto beta_t = std::make_shared<ov::op::v1::Transpose>(beta, perm_bhs);

@@ -275,10 +374,10 @@ void GatedDeltaNet::SetUp() {
                             static_cast<size_t>(v_head_num),
                             static_cast<size_t>(v_head_size)};
     const ov::Shape h_shape{static_cast<size_t>(batch),
-                            static_cast<size_t>(qk_head_num),
+                            static_cast<size_t>(v_head_num),
                             static_cast<size_t>(qk_head_size),
                             static_cast<size_t>(v_head_size)};
-    const ov::Shape g_shape{static_cast<size_t>(batch), static_cast<size_t>(seq_len), static_cast<size_t>(qk_head_num)};
+    const ov::Shape g_shape{static_cast<size_t>(batch), static_cast<size_t>(seq_len), static_cast<size_t>(v_head_num)};

     init_input_shapes(static_shapes_to_test_representation({q_shape, q_shape, v_shape, h_shape, g_shape, g_shape}));

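A worked shape example for the new grouped-query coverage, assuming the parameter order {batch, seq_len, qk_head_num, v_head_num, qk_head_size, v_head_size} matches the test list above; for the case {1, 16, 2, 4, 16, 16} the six test inputs are:

// q, k   : [1, 16, 2, 16]   (qk_head_num heads of size qk_head_size)
// v      : [1, 16, 4, 16]   (v_head_num heads of size v_head_size)
// h      : [1, 4, 16, 16]   (batch, v_head_num, qk_head_size, v_head_size)
// g, beta: [1, 16, 4]       (one gate/beta per value head, per the updated g_shape)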