diff --git a/src/plugins/intel_cpu/src/nodes/kernels/linear_attn/recurrent_linear_attn.cpp b/src/plugins/intel_cpu/src/nodes/kernels/linear_attn/recurrent_linear_attn.cpp index ffef0770311c97..f0e93e59cc3294 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/linear_attn/recurrent_linear_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/linear_attn/recurrent_linear_attn.cpp @@ -93,34 +93,38 @@ void recurrent_linear_attn(const ov::intel_cpu::PlainTensor& query, const ov::intel_cpu::CpuParallelPtr& cpu_parallel) { size_t B = query.m_dims[0]; size_t T = query.m_dims[1]; - size_t H = query.m_dims[2]; + size_t qk_heads = query.m_dims[2]; size_t K = query.m_dims[3]; + size_t v_heads = value.m_dims[2]; size_t V = value.m_dims[3]; const size_t K_HEAD_DIMS = K; const size_t V_HEAD_DIMS = V; const float q_scale = 1 / std::sqrt(static_cast(K_HEAD_DIMS)); - cpu_parallel->parallel_for3d(B, H, V, [&](size_t i_b, size_t i_h, size_t i_v) { + const size_t group_size = v_heads / qk_heads; + cpu_parallel->parallel_for3d(B, v_heads, V, [&](size_t i_b, size_t i_h, size_t i_v) { size_t tid = parallel_get_thread_num(); float* init_state = temp_buffer + tid * 3 * K_HEAD_DIMS; float* b_k = temp_buffer + tid * 3 * K_HEAD_DIMS + K_HEAD_DIMS; float* b_q = temp_buffer + tid * 3 * K_HEAD_DIMS + 2 * K_HEAD_DIMS; - // B, T, H, K - float* q_ptr = query.ptr(i_b, 0, i_h); - float* k_ptr = key.ptr(i_b, 0, i_h); + const size_t hk = i_h / group_size; + // B, T, qk, K + float* q_ptr = query.ptr(i_b, 0, hk); + float* k_ptr = key.ptr(i_b, 0, hk); + // B, T, v_heads, V float* v_ptr = value.ptr(i_b, 0, i_h); - // B, H, K, V + // B, v_heads, K, V for (size_t j = 0; j < K_HEAD_DIMS; j++) { init_state[j] = recurrent_state.at({i_b, i_h, j, i_v}); } for (size_t i = 0; i < T; i++) { - // gate: B, T, H + // gate: B, T, v_heads float b_g = gate.at({i_b, i, i_h}); float b_beta = beta.at({i_b, i, i_h}); b_g = exp(b_g); for (size_t j = 0; j < K_HEAD_DIMS; j++) { - b_k[j] = k_ptr[i * H * K_HEAD_DIMS + j]; - b_q[j] = q_ptr[i * H * K_HEAD_DIMS + j]; + b_k[j] = k_ptr[i * qk_heads * K_HEAD_DIMS + j]; + b_q[j] = q_ptr[i * qk_heads * K_HEAD_DIMS + j]; } if (use_qk_l2norm) { l2norm(b_k, K_HEAD_DIMS, k_l2_norm_eps); @@ -130,8 +134,8 @@ void recurrent_linear_attn(const ov::intel_cpu::PlainTensor& query, // h0 * gate multiply_scalar(init_state, init_state, b_g, K_HEAD_DIMS); float h_k = dot_product(init_state, b_k, K_HEAD_DIMS, nullptr, nullptr, nullptr, 0); - // B, T, H, V - float b_v = v_ptr[i_v + i * H * V_HEAD_DIMS]; + // B, T, v_heads, V + float b_v = v_ptr[i_v + i * v_heads * V_HEAD_DIMS]; b_v -= h_k; // b_v * b_k b_v *= b_beta; diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/gated_delta_net.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/gated_delta_net.cpp index c7797c5d580fe3..0ad3d10582581f 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/gated_delta_net.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/gated_delta_net.cpp @@ -11,6 +11,9 @@ std::vector test_cases = { {1, 39, 2, 2, 16, 32, ov::element::f32, "CPU"}, {2, 16, 2, 2, 16, 16, ov::element::f32, "CPU"}, {2, 39, 2, 2, 16, 16, ov::element::f32, "CPU"}, + // grouped-query cases: qk_heads != v_heads + {1, 16, 2, 4, 16, 16, ov::element::f32, "CPU"}, + {2, 8, 4, 8, 16, 16, ov::element::f32, "CPU"}, {1, 16, 2, 2, 128, 128, ov::element::f32, "CPU"}, {1, 16, 2, 2, 64, 128, ov::element::f32, "CPU"}, {1, 31, 2, 2, 128, 128, ov::element::f32, "CPU"}, diff --git a/src/tests/functional/plugin/shared/include/subgraph_tests/gated_delta_net.hpp b/src/tests/functional/plugin/shared/include/subgraph_tests/gated_delta_net.hpp index 526381e13b97ac..ad1d82da21e46c 100644 --- a/src/tests/functional/plugin/shared/include/subgraph_tests/gated_delta_net.hpp +++ b/src/tests/functional/plugin/shared/include/subgraph_tests/gated_delta_net.hpp @@ -14,7 +14,6 @@ TEST_P(GatedDeltaNet, CompareWithRefs) { auto function = compiledModel.get_runtime_model(); CheckNumberOfNodesWithType(function, {"GatedDeltaNet"}, 1); CheckNumberOfNodesWithType(function, {"Transpose"}, 0); - CheckNumberOfNodesWithType(function, {"Concat"}, 0); CheckNumberOfNodesWithType(function, {"ReduceSum"}, 0); CheckNumberOfNodesWithType(function, {"Multiply"}, 0); CheckNumberOfNodesWithType(function, {"Divide"}, 0); diff --git a/src/tests/functional/plugin/shared/src/subgraph/gated_delta_net.cpp b/src/tests/functional/plugin/shared/src/subgraph/gated_delta_net.cpp index 056a2e33a53b2a..ea40fbf2f16b1a 100644 --- a/src/tests/functional/plugin/shared/src/subgraph/gated_delta_net.cpp +++ b/src/tests/functional/plugin/shared/src/subgraph/gated_delta_net.cpp @@ -6,6 +6,8 @@ #include #include +#include +#include #include "common_test_utils/ov_tensor_utils.hpp" #include "openvino/core/type/bfloat16.hpp" @@ -19,10 +21,15 @@ #include "openvino/op/exp.hpp" #include "openvino/op/gated_delta_net.hpp" #include "openvino/op/gather.hpp" +#include "openvino/op/gather_nd.hpp" +#include "openvino/op/less.hpp" #include "openvino/op/loop.hpp" #include "openvino/op/multiply.hpp" +#include "openvino/op/non_zero.hpp" #include "openvino/op/parameter.hpp" #include "openvino/op/power.hpp" +#include "openvino/op/range.hpp" +#include "openvino/op/reduce_max.hpp" #include "openvino/op/reduce_prod.hpp" #include "openvino/op/reduce_sum.hpp" #include "openvino/op/reshape.hpp" @@ -34,6 +41,7 @@ #include "openvino/op/subtract.hpp" #include "openvino/op/transpose.hpp" #include "openvino/op/unsqueeze.hpp" +#include "openvino/op/variadic_split.hpp" #include "openvino/runtime/properties.hpp" namespace ov { @@ -48,8 +56,8 @@ std::shared_ptr GatedDeltaNet::buildLoopedGDN(int32_t batch, ov::element::Type dtype) { const ov::PartialShape qk_shape{batch, seq_len, qk_head_num, qk_head_size}; const ov::PartialShape v_tensor_shape{batch, seq_len, v_head_num, v_head_size}; - const ov::PartialShape gv_shape{batch, seq_len, qk_head_num}; - const ov::PartialShape h_shape{batch, qk_head_num, qk_head_size, v_head_size}; + const ov::PartialShape gv_shape{batch, seq_len, v_head_num}; + const ov::PartialShape h_shape{batch, v_head_num, qk_head_size, v_head_size}; auto q = std::make_shared(dtype, qk_shape); auto k = std::make_shared(dtype, qk_shape); @@ -65,6 +73,39 @@ std::shared_ptr GatedDeltaNet::buildLoopedGDN(int32_t batch, g->set_friendly_name("g"); beta->set_friendly_name("beta"); + const bool need_head_repeat = (qk_head_num != v_head_num); + + auto repeat_qk_heads = + [&](const ov::Output& query, + const ov::Output& key) -> std::pair, ov::Output> { + if (!need_head_repeat) { + return {query, key}; + } + + const int64_t group_size = static_cast(v_head_num / qk_head_num); + + std::vector repeated_head_ids; + repeated_head_ids.reserve(static_cast(v_head_num)); + for (int64_t h = 0; h < static_cast(v_head_num); ++h) { + repeated_head_ids.push_back(h / group_size); + } + + auto gather_indices = + ov::op::v0::Constant::create(ov::element::i64, {static_cast(v_head_num)}, repeated_head_ids); + auto gather_axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {2}); + + auto repeated_q = std::make_shared(query, gather_indices, gather_axis, 0); + auto repeated_k = std::make_shared(key, gather_indices, gather_axis, 0); + + auto repeated_shape = ov::op::v0::Constant::create( + ov::element::i64, + ov::Shape{4}, + std::vector{0, 0, static_cast(v_head_num), static_cast(qk_head_size)}); + auto repeated_q_reshape = std::make_shared(repeated_q, repeated_shape, true); + auto repeated_k_reshape = std::make_shared(repeated_k, repeated_shape, true); + return {repeated_q_reshape, repeated_k_reshape}; + }; + auto l2norm = [&](const ov::Output& x) { auto sq = std::make_shared(x, x); auto axis = ov::op::v0::Constant::create(ov::element::i32, {1}, {-1}); @@ -76,10 +117,68 @@ std::shared_ptr GatedDeltaNet::buildLoopedGDN(int32_t batch, return std::make_shared(x, inv); }; - auto q_norm = l2norm(q); - auto k_norm = l2norm(k); + ov::Output q_for_attn = q; + ov::Output k_for_attn = k; + ov::Output v_for_attn = v; + + if (need_head_repeat) { + auto flatten_q_shape = ov::op::v0::Constant::create(ov::element::i64, + ov::Shape{4}, + std::vector{static_cast(0), + static_cast(0), + static_cast(1), + static_cast(-1)}); + auto flatten_v_shape = ov::op::v0::Constant::create(ov::element::i64, + ov::Shape{4}, + std::vector{static_cast(0), + static_cast(0), + static_cast(1), + static_cast(-1)}); + + auto q_flat = std::make_shared(q, flatten_q_shape, true); + auto k_flat = std::make_shared(k, flatten_q_shape, true); + auto v_flat = std::make_shared(v, flatten_v_shape, true); + + auto qkv_concat = std::make_shared(ov::OutputVector{q_flat, k_flat, v_flat}, -1); + auto split_axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {-1}); + auto split_lengths = + ov::op::v0::Constant::create(ov::element::i64, + {3}, + {static_cast(qk_head_num) * static_cast(qk_head_size), + static_cast(qk_head_num) * static_cast(qk_head_size), + static_cast(v_head_num) * static_cast(v_head_size)}); + auto qkv_split = std::make_shared(qkv_concat, split_axis, split_lengths); + + auto q_shape = ov::op::v0::Constant::create(ov::element::i64, + ov::Shape{4}, + std::vector{static_cast(0), + static_cast(0), + static_cast(qk_head_num), + static_cast(qk_head_size)}); + auto k_shape = ov::op::v0::Constant::create(ov::element::i64, + ov::Shape{4}, + std::vector{static_cast(0), + static_cast(0), + static_cast(qk_head_num), + static_cast(qk_head_size)}); + auto v_shape_split = ov::op::v0::Constant::create(ov::element::i64, + ov::Shape{4}, + std::vector{static_cast(0), + static_cast(0), + static_cast(v_head_num), + static_cast(v_head_size)}); + + q_for_attn = std::make_shared(qkv_split->output(0), q_shape, true); + k_for_attn = std::make_shared(qkv_split->output(1), k_shape, true); + v_for_attn = std::make_shared(qkv_split->output(2), v_shape_split, true); + + std::tie(q_for_attn, k_for_attn) = repeat_qk_heads(q_for_attn, k_for_attn); + } - auto v_shape = std::make_shared(v); + auto q_norm = l2norm(q_for_attn); + auto k_norm = l2norm(k_for_attn); + + auto v_shape = std::make_shared(v_for_attn); auto core_attn_init = std::make_shared(ov::op::v0::Constant::create(dtype, {}, {0.0f}), v_shape); @@ -87,7 +186,7 @@ std::shared_ptr GatedDeltaNet::buildLoopedGDN(int32_t batch, auto perm_bhs = ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 2, 1}); auto q_norm_t = std::make_shared(q_norm, perm_bhsd); - auto shape_of_q = std::make_shared(q); + auto shape_of_q = std::make_shared(q_for_attn); auto gather_q_perm_index = ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}); auto gather_0_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {0}); auto gather_q_shape = std::make_shared(shape_of_q, gather_q_perm_index, gather_0_axis, 0); @@ -100,7 +199,7 @@ std::shared_ptr GatedDeltaNet::buildLoopedGDN(int32_t batch, auto q_scaled_t = std::make_shared(q_norm_t, q_scale); auto k_norm_t = std::make_shared(k_norm, perm_bhsd); - auto v_t = std::make_shared(v, perm_bhsd); + auto v_t = std::make_shared(v_for_attn, perm_bhsd); auto g_t = std::make_shared(g, perm_bhs); auto beta_t = std::make_shared(beta, perm_bhs); @@ -275,10 +374,10 @@ void GatedDeltaNet::SetUp() { static_cast(v_head_num), static_cast(v_head_size)}; const ov::Shape h_shape{static_cast(batch), - static_cast(qk_head_num), + static_cast(v_head_num), static_cast(qk_head_size), static_cast(v_head_size)}; - const ov::Shape g_shape{static_cast(batch), static_cast(seq_len), static_cast(qk_head_num)}; + const ov::Shape g_shape{static_cast(batch), static_cast(seq_len), static_cast(v_head_num)}; init_input_shapes(static_shapes_to_test_representation({q_shape, q_shape, v_shape, h_shape, g_shape, g_shape}));