diff --git a/src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp b/src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp index 2a75c1385c36e4..f38104caa82490 100644 --- a/src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp @@ -52,33 +52,34 @@ ConvertPagedAttnInputs::ConvertPagedAttnInputs(const KVCacheConfig& config, Upda auto adaptive_rkv_evictable_sizes = pattern::any_input(pattern::has_static_rank()); auto adaptive_rkv_diversity_block_set_indices = pattern::any_input(pattern::has_static_rank()); auto adaptive_rkv_diversity_block_set_indices_begins = pattern::any_input(pattern::has_static_rank()); - - auto result = - pattern::wrap_type({Q, - K, - V, - key_cache_0, - value_cache_0, - past_lens, - subsequence_begins, - block_indices, - block_indices_begins, - scale, - sliding_window, - alibi_slopes, - max_context_len, - score_aggregation_window, - rotated_block_indices, - rotation_deltas, - rotation_trig_lut, - xattention_threshold, - xattention_block_size, - xattention_stride, - sinks, - adaptive_rkv_start_size, - adaptive_rkv_evictable_sizes, - adaptive_rkv_diversity_block_set_indices, - adaptive_rkv_diversity_block_set_indices_begins}); + auto token_type_ids = pattern::any_input(pattern::has_static_rank()); + + auto result = pattern::wrap_type({Q, + K, + V, + key_cache_0, + value_cache_0, + past_lens, + subsequence_begins, + block_indices, + block_indices_begins, + scale, + sliding_window, + alibi_slopes, + max_context_len, + score_aggregation_window, + rotated_block_indices, + rotation_deltas, + rotation_trig_lut, + xattention_threshold, + xattention_block_size, + xattention_stride, + sinks, + adaptive_rkv_start_size, + adaptive_rkv_evictable_sizes, + adaptive_rkv_diversity_block_set_indices, + 
adaptive_rkv_diversity_block_set_indices_begins, + token_type_ids}); ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](pattern::Matcher& m) { const auto pa_op = m.get_match_root(); auto key_cache = ov::as_type_ptr(pa_op->get_input_node_shared_ptr(3)); diff --git a/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp b/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp index 5c317d9b767cfd..acae75479b1c71 100644 --- a/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp +++ b/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp @@ -13,6 +13,7 @@ #include "openvino/op/bitwise_and.hpp" #include "openvino/op/broadcast.hpp" #include "openvino/op/concat.hpp" +#include "openvino/op/convert.hpp" #include "openvino/op/divide.hpp" #include "openvino/op/gather.hpp" #include "openvino/op/greater.hpp" @@ -197,6 +198,18 @@ static std::shared_ptr handle_baichuan2_13b_alibi( return res_alibi_slopes; } +static std::shared_ptr handle_gemma3_token_type_ids( + const std::map>& optional_model_wide_params) { + if (optional_model_wide_params.find("token_type_ids") != optional_model_wide_params.end()) { + auto param = optional_model_wide_params.at("token_type_ids"); + if (param->get_element_type() != ov::element::i32) { + return std::make_shared(param, ov::element::i32); + } + return param; + } + return v0::Constant::create(ov::element::i32, ov::Shape{0}, {}); +} + static std::tuple, std::shared_ptr> phi3_sliding_window_pattern() { auto offset = wrap_type(); auto t196 = wrap_type({any_input(), offset}); @@ -216,7 +229,7 @@ static std::tuple, std::shared_ptr> phi3_sli return {mask, offset}; } -static std::tuple, std::shared_ptr> gpt_oss_sliding_window_pattern() { +static std::tuple, std::shared_ptr> gptoss_gemma3_sliding_window_pattern() { auto q_idx = any_input(); auto kv_idx = 
any_input(); @@ -393,9 +406,9 @@ ov::pass::StateManagementPattern::StateManagementPattern( std::shared_ptr phi3_mask, phi3_offset; std::tie(phi3_mask, phi3_offset) = phi3_sliding_window_pattern(); - // gpt-oss case - std::shared_ptr gpt_oss_mask, gpt_oss_offset; - std::tie(gpt_oss_mask, gpt_oss_offset) = gpt_oss_sliding_window_pattern(); + // gpt-oss and gemma3 cases + std::shared_ptr gptoss_gemma3_mask, gptoss_gemma3_offset; + std::tie(gptoss_gemma3_mask, gptoss_gemma3_offset) = gptoss_gemma3_sliding_window_pattern(); // Scale's shape limitations according to SDPA specification auto scale_predicate = [=](const Output& output) -> bool { @@ -414,7 +427,7 @@ ov::pass::StateManagementPattern::StateManagementPattern( general_alibi_mask, jais_alibi_mask, baichuan2_13b_alibi_mask, - gpt_oss_mask, + gptoss_gemma3_mask, any_input()}); auto sdpa_with_4_inputs = wrap_type({q, k_to_sdpa, v_to_sdpa, mask_to_sdpa}); @@ -425,6 +438,11 @@ ov::pass::StateManagementPattern::StateManagementPattern( auto sdpa_variants = std::make_shared(OutputVector{sdpa_with_4_inputs, sdpa_with_5_inputs, sdpa_with_6_inputs}); + // Shared flag to track whether the model is Gemma3, set when any layer matches + // the gptoss_gemma3 sliding window pattern. Combined with the token_type_ids check, + // this uniquely identifies Gemma3 (gpt-oss shares the pattern but lacks token_type_ids). 
+ auto is_gptoss_gemma3 = std::make_shared(false); + ov::matcher_pass_callback callback = [=, &kv_parameters, &model_wide_params, @@ -602,9 +620,10 @@ ov::pass::StateManagementPattern::StateManagementPattern( offset = std::make_shared(offset, element::i32); } sliding_window = std::make_shared(v0::Constant::create(element::i32, Shape{}, {2}), offset); - } else if (pattern_map.count(gpt_oss_offset)) { - auto offset = pattern_map.at(gpt_oss_offset).get_node_shared_ptr(); - if (pattern_map.at(gpt_oss_offset).get_partial_shape().rank() != 0) { + } else if (pattern_map.count(gptoss_gemma3_offset)) { + *is_gptoss_gemma3 = true; + auto offset = pattern_map.at(gptoss_gemma3_offset).get_node_shared_ptr(); + if (pattern_map.at(gptoss_gemma3_offset).get_partial_shape().rank() != 0) { offset = std::make_shared(offset); } if (offset->get_element_type() != element::i32) { @@ -737,6 +756,13 @@ ov::pass::StateManagementPattern::StateManagementPattern( } OPENVINO_ASSERT(pa_arguments.size() == 25); + if (*is_gptoss_gemma3) { + pa_arguments.insert(pa_arguments.begin() + 25, handle_gemma3_token_type_ids(optional_model_wide_params)); + } else { + pa_arguments.insert(pa_arguments.begin() + 25, v0::Constant::create(element::i32, Shape{0}, {})); + } + OPENVINO_ASSERT(pa_arguments.size() == 26); + auto paged_attention = std::make_shared(pa_arguments); paged_attention->get_rt_info()[NUM_K_HEADS] = num_k_heads; paged_attention->get_rt_info()[K_HEAD_SIZE] = k_head_size; diff --git a/src/common/transformations/tests/common_optimizations/convert_pagedattn_inputs.cpp b/src/common/transformations/tests/common_optimizations/convert_pagedattn_inputs.cpp index 964151c724177b..5b2c79f006498f 100644 --- a/src/common/transformations/tests/common_optimizations/convert_pagedattn_inputs.cpp +++ b/src/common/transformations/tests/common_optimizations/convert_pagedattn_inputs.cpp @@ -121,33 +121,35 @@ TEST_P(ConvertPagedAttnInputsTest, checkPrecisionAndShape) { std::make_shared(ov::element::i32, 
PartialShape{DYN}); auto adaptive_rkv_diversity_block_set_indices_begins = std::make_shared(ov::element::i32, PartialShape{DYN}); + auto token_type_ids = std::make_shared(ov::element::i32, ov::Shape{0}); - auto pa = std::make_shared( - OutputVector{Q, - K, - V, - key_cache_0, - value_cache_0, - past_lens, - subsequence_begins, - block_indices, - block_indices_begins, - scale, - sliding_window, - alibi_slopes, - max_context_len, - score_aggregation_window, - rotated_block_indices, - rotation_deltas, - rotation_trig_lut, - xattention_threshold, - xattention_block_size, - xattention_stride, - sinks, - adaptive_rkv_start_size, - adaptive_rkv_evictable_sizes, - adaptive_rkv_diversity_block_set_indices, - adaptive_rkv_diversity_block_set_indices_begins}); + auto pa = + std::make_shared(OutputVector{Q, + K, + V, + key_cache_0, + value_cache_0, + past_lens, + subsequence_begins, + block_indices, + block_indices_begins, + scale, + sliding_window, + alibi_slopes, + max_context_len, + score_aggregation_window, + rotated_block_indices, + rotation_deltas, + rotation_trig_lut, + xattention_threshold, + xattention_block_size, + xattention_stride, + sinks, + adaptive_rkv_start_size, + adaptive_rkv_evictable_sizes, + adaptive_rkv_diversity_block_set_indices, + adaptive_rkv_diversity_block_set_indices_begins, + token_type_ids}); pa->get_rt_info()["num_k_heads"] = numKeyHeads; pa->get_rt_info()["k_head_size"] = keyHeadSize; pa->get_rt_info()["num_v_heads"] = numValueHeads; @@ -174,7 +176,8 @@ TEST_P(ConvertPagedAttnInputsTest, checkPrecisionAndShape) { adaptive_rkv_start_size, adaptive_rkv_evictable_sizes, adaptive_rkv_diversity_block_set_indices, - adaptive_rkv_diversity_block_set_indices_begins}); + adaptive_rkv_diversity_block_set_indices_begins, + token_type_ids}); if (isIRKVCacheF16) { model->set_rt_info("f16", "runtime_options", ov::hint::kv_cache_precision.name()); @@ -254,33 +257,35 @@ TEST_P(ConvertPagedAttnInputsTest, checkPrecisionAndShape) { 
std::make_shared(ov::element::i32, PartialShape{DYN}); auto adaptive_rkv_diversity_block_set_indices_begins = std::make_shared(ov::element::i32, PartialShape{DYN}); + auto token_type_ids = std::make_shared(ov::element::i32, ov::Shape{0}); - auto pa = std::make_shared( - OutputVector{Q, - K, - V, - key_cache_0, - value_cache_0, - past_lens, - subsequence_begins, - block_indices, - block_indices_begins, - scale, - sliding_window, - alibi_slopes, - max_context_len, - score_aggregation_window, - rotated_block_indices, - rotation_deltas, - rotation_trig_lut, - xattention_threshold, - xattention_block_size, - xattention_stride, - sinks, - adaptive_rkv_start_size, - adaptive_rkv_evictable_sizes, - adaptive_rkv_diversity_block_set_indices, - adaptive_rkv_diversity_block_set_indices_begins}); + auto pa = + std::make_shared(OutputVector{Q, + K, + V, + key_cache_0, + value_cache_0, + past_lens, + subsequence_begins, + block_indices, + block_indices_begins, + scale, + sliding_window, + alibi_slopes, + max_context_len, + score_aggregation_window, + rotated_block_indices, + rotation_deltas, + rotation_trig_lut, + xattention_threshold, + xattention_block_size, + xattention_stride, + sinks, + adaptive_rkv_start_size, + adaptive_rkv_evictable_sizes, + adaptive_rkv_diversity_block_set_indices, + adaptive_rkv_diversity_block_set_indices_begins, + token_type_ids}); pa->get_rt_info()["num_k_heads"] = numKeyHeads; pa->get_rt_info()["k_head_size"] = keyHeadSize; pa->get_rt_info()["num_v_heads"] = numValueHeads; @@ -307,7 +312,8 @@ TEST_P(ConvertPagedAttnInputsTest, checkPrecisionAndShape) { adaptive_rkv_start_size, adaptive_rkv_evictable_sizes, adaptive_rkv_diversity_block_set_indices, - adaptive_rkv_diversity_block_set_indices_begins}); + adaptive_rkv_diversity_block_set_indices_begins, + token_type_ids}); } ov::pass::ConvertPagedAttnInputs::KVCacheConfig cacheConfig; cacheConfig.keyCacheBlockSize = blockSize[0]; diff --git 
a/src/common/transformations/tests/op_conversions/sdpa_to_paged_attention_test.cpp b/src/common/transformations/tests/op_conversions/sdpa_to_paged_attention_test.cpp index caaefd11e0525e..9378533a1e9d8d 100644 --- a/src/common/transformations/tests/op_conversions/sdpa_to_paged_attention_test.cpp +++ b/src/common/transformations/tests/op_conversions/sdpa_to_paged_attention_test.cpp @@ -619,34 +619,36 @@ TEST_P(SDPAToPATest, SDPAToPA_Qwen7bChat_General) { auto scale = std::make_shared(element::f32, Shape{}, MOCK_VALUE); auto score_aggregation_window_const = std::make_shared(element::i32, Shape{0}, 0); auto sinks = v0::Constant::create(element::f32, Shape{0, 0, 0, 0}, {}); + auto token_type_ids = v0::Constant::create(element::i32, Shape{0}, {}); // PagedAttention: - auto pa = std::make_shared( - OutputVector{Q, - K, - V, - key_cache_0, - value_cache_0, - past_lens, - subsequence_begins, - block_indices, - block_indices_begins, - scale, - sliding_window, - alibi_slopes, - max_context_len, - score_aggregation_window_const, - rotated_block_indices, - rotation_deltas, - rotation_trig_lut, - xattention_threshold, - xattention_block_size, - xattention_stride, - sinks, - adaptive_rkv_start_size, - adaptive_rkv_evictable_sizes, - adaptive_rkv_diversity_block_set_indices, - adaptive_rkv_diversity_block_set_indices_begins}); + auto pa = + std::make_shared(OutputVector{Q, + K, + V, + key_cache_0, + value_cache_0, + past_lens, + subsequence_begins, + block_indices, + block_indices_begins, + scale, + sliding_window, + alibi_slopes, + max_context_len, + score_aggregation_window_const, + rotated_block_indices, + rotation_deltas, + rotation_trig_lut, + xattention_threshold, + xattention_block_size, + xattention_stride, + sinks, + adaptive_rkv_start_size, + adaptive_rkv_evictable_sizes, + adaptive_rkv_diversity_block_set_indices, + adaptive_rkv_diversity_block_set_indices_begins, + token_type_ids}); pa->set_out_type(0, element::i64); auto pa_aligned = Qwen7bChatPA::align_pa_layout(pa, 
head_size_2); auto res = makeOP({pa_aligned}); @@ -1011,6 +1013,7 @@ TEST_F(SDPAToPATest, SDPAToPA_Baichuan2_13b_General) { auto c1 = makeConst(element::f32, {}, {0.088388f}); auto c2 = makeConst(element::i32, {}, {0}); auto sinks = v0::Constant::create(element::f32, Shape{0, 0, 0, 0}, {}); + auto token_type_ids = v0::Constant::create(element::i32, Shape{0}, {}); auto PagedAttentionExtension168 = std::make_shared( ov::OutputVector{Reshape138, Reshape147, @@ -1036,7 +1039,8 @@ TEST_F(SDPAToPATest, SDPAToPA_Baichuan2_13b_General) { adaptive_rkv_start_size, adaptive_rkv_evictable_sizes, adaptive_rkv_diversity_block_set_indices, - adaptive_rkv_diversity_block_set_indices_begins}); + adaptive_rkv_diversity_block_set_indices_begins, + token_type_ids}); auto ShapeOf172 = makeOP({Transpose154}, {{"output_type", "i64"}}); auto Gather175 = makeOP({ShapeOf172, -1, 0}, {{"batch_dims", 0}}); auto Unsqueeze177 = makeOP({Gather175, 0}); @@ -1382,6 +1386,7 @@ TEST_F(SDPAToPATest, SDPAToPA_nanoLLaVA_General) { // an empty Constant needs to be created in a usual way, not using makeConst() auto c3 = v0::Constant::create(element::f32, {0}, {}); auto sinks = v0::Constant::create(element::f32, Shape{0, 0, 0, 0}, {}); + auto token_type_ids = v0::Constant::create(element::i32, Shape{0}, {}); auto PagedAttentionExtension_51962 = std::make_shared( ov::OutputVector{Reshape_51953, Reshape_51957, @@ -1407,7 +1412,8 @@ TEST_F(SDPAToPATest, SDPAToPA_nanoLLaVA_General) { adaptive_rkv_start_size, adaptive_rkv_evictable_sizes, adaptive_rkv_diversity_block_set_indices, - adaptive_rkv_diversity_block_set_indices_begins}); + adaptive_rkv_diversity_block_set_indices_begins, + token_type_ids}); auto ShapeOf_51965 = makeOP({Transpose_51955}, {{"output_type", "i64"}}); auto Gather_51966 = makeOP({ShapeOf_51965, -1, 0}, {{"batch_dims", 0}}); auto Unsqueeze_51971 = makeOP({Gather_51966, 0}); @@ -1709,6 +1715,7 @@ TEST_F(SDPAToPATest, SDPAToPA_Phi3_mini_4k_instruct) { auto scale = 
v0::Constant::create(element::f32, {}, {0.102062f}); auto alibi_slopes = v0::Constant::create(element::f32, Shape{0}, {}); auto sinks = v0::Constant::create(element::f32, Shape{0, 0, 0, 0}, {}); + auto token_type_ids = v0::Constant::create(element::i32, Shape{0}, {}); auto PagedAttentionExtension = std::make_shared( OutputVector{Q, K, @@ -1734,7 +1741,8 @@ TEST_F(SDPAToPATest, SDPAToPA_Phi3_mini_4k_instruct) { adaptive_rkv_start_size, adaptive_rkv_evictable_sizes, adaptive_rkv_diversity_block_set_indices, - adaptive_rkv_diversity_block_set_indices_begins}); + adaptive_rkv_diversity_block_set_indices_begins, + token_type_ids}); auto ShapeOf1 = makeOP({Transpose6}, {{"output_type", "i64"}}); auto Gather2 = makeOP({ShapeOf1, -1, 0}, {{"batch_dims", 0}}); auto Unsqueeze5 = makeOP({Gather2, 0}); @@ -2055,6 +2063,7 @@ TEST_F(SDPAToPATest, SDPAToPA_Codegen2) { auto scale = v0::Constant::create(element::f32, {}, {0.062500f}); auto alibi_slopes_stub = v0::Constant::create(element::f32, Shape{0}, {}); auto sinks = v0::Constant::create(element::f32, Shape{0, 0, 0, 0}, {}); + auto token_type_ids = v0::Constant::create(element::i32, Shape{0}, {}); auto PagedAttentionExtension = std::make_shared( OutputVector{Reshape11, Reshape13, @@ -2080,7 +2089,8 @@ TEST_F(SDPAToPATest, SDPAToPA_Codegen2) { adaptive_rkv_start_size, adaptive_rkv_evictable_sizes, adaptive_rkv_diversity_block_set_indices, - adaptive_rkv_diversity_block_set_indices_begins}); + adaptive_rkv_diversity_block_set_indices_begins, + token_type_ids}); auto ShapeOf2 = makeOP({Transpose7}, {{"output_type", "i64"}}); auto Gather5 = makeOP({ShapeOf2, -1, 0}, {{"batch_dims", 0}}); auto Unsqueeze9 = makeOP({Gather5, 0}); @@ -2715,6 +2725,7 @@ TEST_F(SDPAToPATest, SDPAToPA_gpt_oss_General) { auto sliding_window = makeOP({Convert16, -1}, {{"auto_broadcast", "numpy"}}); auto scale = v0::Constant::create(element::f32, {}, {0.1250f}); auto alibi_slopes_stub = v0::Constant::create(element::f32, Shape{0}, {}); + auto token_type_ids 
= v0::Constant::create(element::i32, Shape{0}, {}); auto PagedAttentionExtension = std::make_shared( OutputVector{Reshape1, Reshape3, @@ -2740,7 +2751,8 @@ TEST_F(SDPAToPATest, SDPAToPA_gpt_oss_General) { adaptive_rkv_start_size, adaptive_rkv_evictable_sizes, adaptive_rkv_diversity_block_set_indices, - adaptive_rkv_diversity_block_set_indices_begins}); + adaptive_rkv_diversity_block_set_indices_begins, + token_type_ids}); auto ShapeOf3 = makeOP({Transpose6}, {{"output_type", "i64"}}); auto Gather4 = makeOP({ShapeOf3, -1, 0}, {{"batch_dims", 0}}); auto Unsqueeze5 = makeOP({Gather4, 0}); @@ -4129,6 +4141,7 @@ TEST_F(SDPAToPATest, SDPAToPA_LFM2) { auto scale = v0::Constant::create(element::f32, {}, {0.500000f}); auto sliding_window = v0::Constant::create(element::i32, {}, {0}); auto alibi_slopes_stub = v0::Constant::create(element::f32, Shape{0}, {}); + auto token_type_ids = v0::Constant::create(element::i32, Shape{0}, {}); auto PagedAttentionExtension0 = std::make_shared( OutputVector{Reshape18, Reshape20, @@ -4154,7 +4167,8 @@ TEST_F(SDPAToPATest, SDPAToPA_LFM2) { adaptive_rkv_start_size, adaptive_rkv_evictable_sizes, adaptive_rkv_diversity_block_set_indices, - adaptive_rkv_diversity_block_set_indices_begins}); + adaptive_rkv_diversity_block_set_indices_begins, + token_type_ids}); auto ShapeOf12 = makeOP({Transpose15}, {{"output_type", "i64"}}); auto Gather10 = makeOP({ShapeOf12, -1, 0}, {{"batch_dims", 0}}); auto Unsqueeze10 = makeOP({Gather10, 0}); @@ -4712,6 +4726,7 @@ TEST_F(SDPAToPATest, SDPAToPA_jais_13b_General) { auto Constant21 = makeConst(element::i32, ov::Shape({0}), {0}); auto Constant22 = makeConst(element::i32, ov::Shape({0}), {0}); auto Constant23 = makeConst(element::i32, ov::Shape({0}), {0}); + auto Constant24 = v0::Constant::create(element::i32, Shape{0}, {}); auto PagedAttentionExtension0 = make_shared(OutputVector{Reshape4, Reshape6, @@ -4737,7 +4752,8 @@ TEST_F(SDPAToPATest, SDPAToPA_jais_13b_General) { Constant20, Constant21, Constant22, - 
Constant23}); + Constant23, + Constant24}); auto ShapeOf1 = makeOP({Transpose5}, {{"output_type", "i64"}}); auto Gather2 = makeOP({ShapeOf1, -1, 0}, {{"batch_dims", 0}}); auto Unsqueeze1 = makeOP({Gather2, 0}); @@ -4981,7 +4997,6 @@ TEST_F(SDPAToPATest, SDPATOPATest_Qwen2_5_VL_General) { auto score_aggregation_window = make_param(PartialShape{DYN}, element::i32, "score_aggregation_window"); auto inputs_embeds = make_param(PartialShape{DYN, DYN}, element::f32, "inputs_embeds"); auto position_ids = make_param(PartialShape{3, DYN}, element::i64, "position_ids"); - auto Unsqueeze0 = makeOP({inputs_embeds, 1}); auto Const0 = makeConst(element::f32, ov::Shape({ @@ -5120,6 +5135,7 @@ TEST_F(SDPAToPATest, SDPATOPATest_Qwen2_5_VL_General) { auto adaptive_rkv_evictable_sizes = v0::Constant::create(element::i32, Shape{0}, {}); auto adaptive_rkv_diversity_block_set_indices = v0::Constant::create(element::i32, Shape{0}, {}); auto adaptive_rkv_diversity_block_set_indices_begins = v0::Constant::create(element::i32, Shape{0}, {}); + auto token_type_ids_stub = v0::Constant::create(element::i32, Shape{0}, {}); auto PagedAttentionExtension0 = make_shared(OutputVector{Reshape1, @@ -5146,7 +5162,8 @@ TEST_F(SDPAToPATest, SDPATOPATest_Qwen2_5_VL_General) { adaptive_rkv_start_size, adaptive_rkv_evictable_sizes, adaptive_rkv_diversity_block_set_indices, - adaptive_rkv_diversity_block_set_indices_begins}); + adaptive_rkv_diversity_block_set_indices_begins, + token_type_ids_stub}); auto ShapeOf1 = makeOP({Transpose4}, {{"output_type", "i64"}}); auto Gather13 = makeOP({ShapeOf1, -1, 0}, {{"batch_dims", 0}}); @@ -5176,6 +5193,274 @@ TEST_F(SDPAToPATest, SDPATOPATest_Qwen2_5_VL_General) { } } +// Gemma3 test: same sliding window pattern as gpt_oss, but with token_type_ids as model parameter +TEST_F(SDPAToPATest, SDPAToPA_Gemma3_TokenTypeIds) { + { + auto beam_idx = make_param(PartialShape{DYN}, element::i32, "beam_idx"); + auto position_ids = make_param(PartialShape{DYN, DYN}, element::i64, 
"position_ids"); + auto attention_mask = make_param(PartialShape{DYN, DYN}, element::i64, "attention_mask"); + auto input_ids = make_param(PartialShape{DYN, DYN}, element::i64, "input_ids"); + auto token_type_ids = make_param(PartialShape{1, DYN}, element::i64, "token_type_ids"); + auto params = nodes_to_params({beam_idx, position_ids, attention_mask, input_ids, token_type_ids}); + + auto ShapeOf0 = makeOP({input_ids}, {{"output_type", "i64"}}); + auto Gather0 = makeOP({ShapeOf0, {0}, 0}, {{"batch_dims", 0}}); + + auto Constant0 = makeConst(element::f32, ov::Shape({32000, 128}), MOCK_VALUE); + auto Convert3 = makeOP({input_ids}, {{"destination_type", "i32"}}); + auto Gather1 = makeOP({Constant0, Convert3, 0}, {{"batch_dims", 0}}); + auto Power0 = makeOP({Gather1, single_val(3, 2.0f)}, {{"auto_broadcast", "numpy"}}); + auto ReduceMean0 = makeOP({Power0, {-1}}, {{"keep_dims", true}}); + auto Add0 = makeOP({ReduceMean0, single_val(3, 1e-6f)}, {{"auto_broadcast", "numpy"}}); + auto Sqrt0 = makeOP({Add0}); + auto Divide0 = + makeOP({single_val(3, 1.0f), Sqrt0}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}}); + auto Multiply1 = makeOP({Gather1, Divide0}, {{"auto_broadcast", "numpy"}}); + auto Constant_w = makeConst(element::f32, ov::Shape({1, 1, 128}), MOCK_VALUE); + auto Multiply2 = makeOP({Constant_w, Multiply1}, {{"auto_broadcast", "numpy"}}); + + auto q_weight = makeConst(element::f32, ov::Shape({512, 128}), MOCK_VALUE); + auto MatMul_q = makeOP({Multiply2, q_weight}, {{"transpose_a", false}, {"transpose_b", true}}); + auto Reshape_q = makeOP({MatMul_q, {0, 0, 4, 128}}, {{"special_zero", true}}); + auto Q = makeOP({Reshape_q, {0, 2, 1, 3}}); + + auto k_weight = makeConst(element::f32, ov::Shape({128, 128}), MOCK_VALUE); + auto MatMul_k = makeOP({Multiply2, k_weight}, {{"transpose_a", false}, {"transpose_b", true}}); + auto Reshape_k = makeOP({MatMul_k, {0, 0, 1, 128}}, {{"special_zero", true}}); + auto K_cur = makeOP({Reshape_k, {0, 2, 1, 3}}); + + auto 
v_weight = makeConst(element::f32, ov::Shape({128, 128}), MOCK_VALUE); + auto MatMul_v = makeOP({Multiply2, v_weight}, {{"transpose_a", false}, {"transpose_b", true}}); + auto Reshape_v = makeOP({MatMul_v, {0, 0, 1, 128}}, {{"special_zero", true}}); + auto V_cur = makeOP({Reshape_v, {0, 2, 1, 3}}); + + auto k_init_shape = makeOP({Gather0, {1l}, {0l}, {128l}}, {{"axis", 0}}); + auto k_init = makeOP({0.0f, k_init_shape}, {{"mode", "numpy"}}); + auto k_read = makeOP( + {k_init}, + {{"variable_id", "k_cache"}, {"variable_type", "f32"}, {"variable_shape", PartialShape{DYN, 1, DYN, 128}}}); + auto k_past = makeOP({k_read, beam_idx, 0}, {{"batch_dims", 0}}); + auto k_concat = makeOP({k_past, K_cur}, {{"axis", -2}}); + + auto v_init_shape = makeOP({Gather0, {1l}, {0l}, {128l}}, {{"axis", 0}}); + auto v_init = makeOP({0.0f, v_init_shape}, {{"mode", "numpy"}}); + auto v_read = makeOP( + {v_init}, + {{"variable_id", "v_cache"}, {"variable_type", "f32"}, {"variable_shape", PartialShape{DYN, 1, DYN, 128}}}); + auto v_past = makeOP({v_read, beam_idx, 0}, {{"batch_dims", 0}}); + auto v_concat = makeOP({v_past, V_cur}, {{"axis", -2}}); + + auto k_unsqueeze = makeOP({k_concat, 2}); + auto k_shape = makeOP({k_concat}, {{"output_type", "i64"}}); + auto k_gather_dims = makeOP({k_shape, {0, 1}, 0}, {{"batch_dims", 0}}); + auto k_gather_dims2 = makeOP({k_shape, {2, 3}, 0}, {{"batch_dims", 0}}); + auto k_bcast_shape = makeOP({k_gather_dims, {4l}, k_gather_dims2}, {{"axis", 0}}); + auto k_broadcast = makeOP({k_unsqueeze, k_bcast_shape}, {{"mode", "bidirectional"}}); + auto K = makeOP({k_broadcast, {0, 4, -1, 128}}, {{"special_zero", true}}); + + auto v_unsqueeze = makeOP({v_concat, 2}); + auto V = makeOP( + {makeOP({v_unsqueeze, k_bcast_shape}, {{"mode", "bidirectional"}}), {0, 4, -1, 128}}, + {{"special_zero", true}}); + + // Same pattern as gpt_oss + auto Constant_true1 = makeConst(element::boolean, ov::Shape({}), {1}); + auto Constant_true2 = makeConst(element::boolean, ov::Shape({}), 
{1}); + + auto ShapeOf_pos = makeOP({position_ids}, {{"output_type", "i64"}}); + auto Gather_cur = makeOP({ShapeOf_pos, 1, 0}, {{"batch_dims", 0}}); + auto Reshape_cur = makeOP({Gather_cur, {1}}, {{"special_zero", false}}); + auto Squeeze_cur = makeOP({Reshape_cur, 0}); + + auto ShapeOf_past = makeOP({k_past}, {{"output_type", "i64"}}); + auto Gather_past = makeOP({ShapeOf_past, 2, 0}, {{"batch_dims", 0}}); + + auto total_len = makeOP({Squeeze_cur, Gather_past}, {{"auto_broadcast", "numpy"}}); + auto Range_kv = makeOP({0, total_len, 1}, {{"output_type", "i64"}}); + auto Unsqueeze_kv0 = makeOP({Range_kv, 0}); + auto Unsqueeze_kv1 = makeOP({Unsqueeze_kv0, 1}); + auto Unsqueeze_kv2 = makeOP({Unsqueeze_kv1, 2}); + auto kv_idx = makeOP({Unsqueeze_kv2}, {{"destination_type", "f32"}}); + + auto Range_q_start = makeOP({Gather_past, Gather_cur}, {{"auto_broadcast", "numpy"}}); + auto Range_q = makeOP({Gather_past, Range_q_start, 1}, {{"output_type", "f32"}}); + auto Unsqueeze_q0 = makeOP({Range_q, 0}); + auto Unsqueeze_q1 = makeOP({Unsqueeze_q0, 1}); + auto q_idx = makeOP({Unsqueeze_q1, 3}); + + auto sw_offset = makeConst(element::f32, ov::Shape({1, 1, 1, 1}), {-1024.0f}); + auto sw_add = makeOP({q_idx, sw_offset}, {{"auto_broadcast", "numpy"}}); + auto sw_greater = makeOP({kv_idx, sw_add}, {{"auto_broadcast", "numpy"}}); + + auto causal_le = makeOP({kv_idx, q_idx}, {{"auto_broadcast", "numpy"}}); + auto BitwiseAnd0 = makeOP({Constant_true2, sw_greater}, {{"auto_broadcast", "numpy"}}); + auto BitwiseAnd1 = makeOP({BitwiseAnd0, causal_le}, {{"auto_broadcast", "numpy"}}); + auto BitwiseAnd2 = makeOP({Constant_true1, BitwiseAnd1}, {{"auto_broadcast", "numpy"}}); + + auto Convert_am = makeOP({attention_mask}, {{"destination_type", "boolean"}}); + auto ShapeOf_am = makeOP({Convert_am}, {{"output_type", "i32"}}); + auto ReduceProd_am = makeOP({ShapeOf_am, 0}, {{"keep_dims", true}}); + auto Concat_am = makeOP({ReduceProd_am, {-1}}, {{"axis", 0}}); + auto Reshape_am = 
makeOP({Convert_am, Concat_am}, {{"special_zero", true}}); + auto kv_idx_i32 = makeOP({Unsqueeze_kv2}, {{"destination_type", "i32"}}); + auto Gather0_batch = makeOP({ShapeOf_pos, {0}, 0}, {{"batch_dims", 0}}); + auto Squeeze_batch = makeOP({Gather0_batch}); + auto Range_batch = makeOP({0, Squeeze_batch, 1}, {{"output_type", "i64"}}); + auto Unsq_b0 = makeOP({Range_batch, 1}); + auto Unsq_b1 = makeOP({Unsq_b0, 2}); + auto Unsq_b2 = makeOP({Unsq_b1, 3}); + auto batch_idx = makeOP({Unsq_b2}, {{"destination_type", "i32"}}); + auto Split_am = makeOP({ShapeOf_am, 0}, {{"num_splits", 2}}); + auto Multiply_idx = makeOP({batch_idx, Split_am->output(1)}, {{"auto_broadcast", "numpy"}}); + auto flat_idx = makeOP({kv_idx_i32, Multiply_idx}, {{"auto_broadcast", "numpy"}}); + auto Gather_am = makeOP({Reshape_am, flat_idx, 0}, {{"batch_dims", 0}}); + auto Reshape_am2 = makeOP({Gather_am, {-1}}, {{"special_zero", false}}); + auto ShapeOf_idx = makeOP({flat_idx}, {{"output_type", "i32"}}); + auto Reshape_am3 = makeOP({Reshape_am2, ShapeOf_idx}, {{"special_zero", false}}); + + auto BitwiseAnd3 = makeOP({BitwiseAnd2, Reshape_am3}, {{"auto_broadcast", "numpy"}}); + auto total_len_unsq = makeOP({total_len, 0}); + auto bcast_shape = makeOP({Gather0_batch, {1l}, Reshape_cur, total_len_unsq}, {{"axis", 0}}); + auto Broadcast_mask = makeOP({BitwiseAnd3, bcast_shape}, {{"mode", "bidirectional"}}); + auto Select_mask = makeOP({Broadcast_mask, 0.0f, -65504.0f}, {{"auto_broadcast", "numpy"}}); + auto past_len_reshape = makeOP({Gather_past, {1}}, {{"special_zero", false}}); + auto slice_end = makeOP({past_len_reshape, Reshape_cur}, {{"auto_broadcast", "numpy"}}); + auto Slice_mask = makeOP({Select_mask, {0}, slice_end, {1}, {3}}); + + auto ScaledDotProductAttention = + makeOP({Q, K, V, Slice_mask, 0.125f}, {{"causal", false}}); + auto res = make_shared(ScaledDotProductAttention); + + model = std::make_shared(OutputVector{res}, params); + manager.register_pass(); + } + + { + auto max_context_len 
= make_param(PartialShape{}, element::i32, "max_context_len"); + auto block_indices_begins = make_param(PartialShape{DYN}, element::i32, "block_indices_begins"); + auto block_indices = make_param(PartialShape{DYN}, element::i32, "block_indices"); + auto subsequence_begins = make_param(PartialShape{DYN}, element::i32, "subsequence_begins"); + auto past_lens = make_param(PartialShape{DYN}, element::i32, "past_lens"); + auto value_cache_0 = make_param(PartialShape{DYN, DYN, DYN, DYN}, element::dynamic, "value_cache.0"); + auto key_cache_0 = make_param(PartialShape{DYN, DYN, DYN, DYN}, element::dynamic, "key_cache.0"); + auto input_ids = make_param(PartialShape{DYN}, element::i64, "input_ids"); + auto position_ids = make_param(PartialShape{DYN}, element::i64, "position_ids"); + auto token_type_ids_param = make_param(PartialShape{1, DYN}, element::i64, "token_type_ids"); + + auto score_aggregation_window = makeConst(element::i32, ov::Shape({0}), {0}); + auto rotated_block_indices = makeConst(element::i32, ov::Shape({0}), {0}); + auto rotation_deltas = makeConst(element::i32, ov::Shape{0}, {0}); + auto rotation_trig_lut = makeConst(element::f32, ov::Shape({0}), {0}); + auto xattention_threshold = makeConst(element::f32, ov::Shape({0}), {0}); + auto xattention_block_size = makeConst(element::i32, ov::Shape({}), {0}); + auto xattention_stride = makeConst(element::i32, ov::Shape({}), {0}); + auto adaptive_rkv_start_size = makeConst(element::i32, ov::Shape({}), MOCK_VALUE); + auto adaptive_rkv_evictable_sizes = makeConst(element::i32, ov::Shape({0}), {0}); + auto adaptive_rkv_diversity_block_set_indices = makeConst(element::i32, ov::Shape({0}), {0}); + auto adaptive_rkv_diversity_block_set_indices_begins = makeConst(element::i32, ov::Shape({0}), {0}); + + auto params = nodes_to_params({max_context_len, + block_indices_begins, + block_indices, + subsequence_begins, + past_lens, + value_cache_0, + key_cache_0, + input_ids, + position_ids, + token_type_ids_param}); + + auto 
Constant0 = makeConst(element::f32, ov::Shape({32000, 128}), MOCK_VALUE); + auto Unsqueeze_ids = makeOP({input_ids, 1}); + auto Convert3 = makeOP({Unsqueeze_ids}, {{"destination_type", "i32"}}); + auto Gather1 = makeOP({Constant0, Convert3, 0}, {{"batch_dims", 0}}); + auto Power0 = makeOP({Gather1, single_val(3, 2.0f)}, {{"auto_broadcast", "numpy"}}); + auto ReduceMean0 = makeOP({Power0, {-1}}, {{"keep_dims", true}}); + auto Add0 = makeOP({ReduceMean0, single_val(3, 1e-6f)}, {{"auto_broadcast", "numpy"}}); + auto Sqrt0 = makeOP({Add0}); + auto Divide0 = + makeOP({single_val(3, 1.0f), Sqrt0}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}}); + auto Multiply1 = makeOP({Gather1, Divide0}, {{"auto_broadcast", "numpy"}}); + auto Constant_w = makeConst(element::f32, ov::Shape({1, 1, 128}), MOCK_VALUE); + auto Multiply2 = makeOP({Constant_w, Multiply1}, {{"auto_broadcast", "numpy"}}); + + auto q_weight = makeConst(element::f32, ov::Shape({512, 128}), MOCK_VALUE); + auto MatMul_q = makeOP({Multiply2, q_weight}, {{"transpose_a", false}, {"transpose_b", true}}); + auto Reshape_q = makeOP({MatMul_q, {0, 0, 4, 128}}, {{"special_zero", true}}); + auto Transpose_q = makeOP({Reshape_q, {0, 2, 1, 3}}); + auto Transpose_q2 = makeOP({Transpose_q, {0, 2, 1, 3}}); + auto Q_flat = makeOP({Transpose_q2, {0, -1}}, {{"special_zero", true}}); + + auto k_weight = makeConst(element::f32, ov::Shape({128, 128}), MOCK_VALUE); + auto MatMul_k = makeOP({Multiply2, k_weight}, {{"transpose_a", false}, {"transpose_b", true}}); + auto Reshape_k = makeOP({MatMul_k, {0, 0, 1, 128}}, {{"special_zero", true}}); + auto Transpose_k = makeOP({Reshape_k, {0, 2, 1, 3}}); + auto Transpose_k2 = makeOP({Transpose_k, {0, 2, 1, 3}}); + auto K_flat = makeOP({Transpose_k2, {0, -1}}, {{"special_zero", true}}); + + auto MatMul_v = makeOP({Multiply2, makeConst(element::f32, ov::Shape({128, 128}), MOCK_VALUE)}, + {{"transpose_a", false}, {"transpose_b", true}}); + auto Reshape_v = makeOP({MatMul_v, {0, 0, 1, 128}}, 
{{"special_zero", true}}); + auto Transpose_v = makeOP({Reshape_v, {0, 2, 1, 3}}); + auto Transpose_v2 = makeOP({Transpose_v, {0, 2, 1, 3}}); + auto V_flat = makeOP({Transpose_v2, {0, -1}}, {{"special_zero", true}}); + + auto sw_neg = makeConst(element::f32, ov::Shape({1, 1, 1, 1}), {-1024.0f}); + auto Squeeze_sw = makeOP({sw_neg}, {{"allow_axis_skip", false}}); + auto Convert_sw = makeOP({Squeeze_sw}, {{"destination_type", "i32"}}); + auto sliding_window = makeOP({Convert_sw, -1}, {{"auto_broadcast", "numpy"}}); + + auto scale = v0::Constant::create(element::f32, Shape{}, {0.125f}); + auto alibi_slopes = v0::Constant::create(element::f32, Shape{0}, {}); + auto sinks = v0::Constant::create(element::f32, Shape{0, 0, 0, 0}, {}); + + auto token_type_ids_i32 = makeOP({token_type_ids_param}, {{"destination_type", "i32"}}); + + auto PA = std::make_shared( + OutputVector{Q_flat, + K_flat, + V_flat, + key_cache_0, + value_cache_0, + past_lens, + subsequence_begins, + block_indices, + block_indices_begins, + scale, + sliding_window, + alibi_slopes, + max_context_len, + score_aggregation_window, + rotated_block_indices, + rotation_deltas, + rotation_trig_lut, + xattention_threshold, + xattention_block_size, + xattention_stride, + sinks, + adaptive_rkv_start_size, + adaptive_rkv_evictable_sizes, + adaptive_rkv_diversity_block_set_indices, + adaptive_rkv_diversity_block_set_indices_begins, + token_type_ids_i32}); + + auto ShapeOf_v = makeOP({Transpose_v2}, {{"output_type", "i64"}}); + auto Gather_dim = makeOP({ShapeOf_v, -1, 0}, {{"batch_dims", 0}}); + auto Unsqueeze_dim = makeOP({Gather_dim, 0}); + auto pa_shape = makeOP({{0l}, {1l}, {-1l}, Unsqueeze_dim}, {{"axis", 0}}); + auto pa_reshape = makeOP({PA->output(0), pa_shape}, {{"special_zero", true}}); + auto pa_transpose = makeOP({pa_reshape, {0, 2, 1, 3}}); + + auto res = makeOP({pa_transpose}); + + model_ref = std::make_shared(OutputVector{res}, params); + + comparator.disable(FunctionsComparator::PRECISIONS); + 
disable_result_friendly_names_check(); + disable_rt_info_check(); + } +} + /* As there's often a need to cover specific model's architecutres in these tests, please, make sure you name the tests in the following manner: diff --git a/src/core/src/op/paged_attention.cpp b/src/core/src/op/paged_attention.cpp index 00824c47f275de..3e16be51ebaa4c 100644 --- a/src/core/src/op/paged_attention.cpp +++ b/src/core/src/op/paged_attention.cpp @@ -79,8 +79,8 @@ void PagedAttentionExtension::validate_and_infer_types() { OV_OP_SCOPE(PagedAttentionExtension_validate_and_infer_types); NODE_VALIDATION_CHECK(this, - get_input_size() == 25, - "PagedAttensionExtension expects 25 inputs, but it has ", + get_input_size() == 26, + "PagedAttensionExtension expects 26 inputs, but it has ", get_input_size()); // format: Node*, input_idx, name, {rank_list}, {type_list} @@ -109,6 +109,7 @@ void PagedAttentionExtension::validate_and_infer_types() { input_check(this, 22, "adaptive_rkv_evictable_sizes", {1}, {element::i32}); input_check(this, 23, "adaptive_rkv_diversity_block_set_indices", {1}, {element::i32}); input_check(this, 24, "adaptive_rkv_diversity_block_set_indices_begins", {1}, {element::i32}); + input_check(this, 25, "token_type_ids", {1, 2}, {element::i32}); // value head_size may be not same with key auto out_ps = get_input_partial_shape(0); diff --git a/src/core/src/pass/sdpa_to_paged_attention.cpp b/src/core/src/pass/sdpa_to_paged_attention.cpp index 2fa03edfe323dc..65c6eeba01583b 100644 --- a/src/core/src/pass/sdpa_to_paged_attention.cpp +++ b/src/core/src/pass/sdpa_to_paged_attention.cpp @@ -151,6 +151,11 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptrvalidate_and_infer_types(); + optional_model_wide_params["token_type_ids"] = token_type_ids_param; + } + std::shared_ptr position_ids; if (!get_parameter(model, "position_ids")) { position_ids = named_parameter(std::make_shared(element::i64, PartialShape{-1}), "position_ids"); diff --git 
a/src/core/tests/type_prop/paged_attention.cpp b/src/core/tests/type_prop/paged_attention.cpp index 5eb2d0772818e7..0911a4bf0a5e94 100644 --- a/src/core/tests/type_prop/paged_attention.cpp +++ b/src/core/tests/type_prop/paged_attention.cpp @@ -45,6 +45,8 @@ TEST(type_prop, paged_attention_static_eviction_per_block) { const auto adaptive_rkv_diversity_block_set_indices_begins = std::make_shared(element::i32, PartialShape{5}); + const auto token_type_ids = std::make_shared(ov::element::i32, ov::Shape{0}); + ov::OutputVector args = {query, key, value, @@ -69,7 +71,8 @@ TEST(type_prop, paged_attention_static_eviction_per_block) { adaptive_rkv_start_size, adaptive_rkv_evictable_sizes, adaptive_rkv_diversity_block_set_indices, - adaptive_rkv_diversity_block_set_indices_begins}; + adaptive_rkv_diversity_block_set_indices_begins, + token_type_ids}; const auto op = std::make_shared(args); EXPECT_EQ(op->get_output_element_type(0), element::f32); @@ -109,6 +112,8 @@ TEST(type_prop, paged_attention_static_eviction_per_token) { const auto adaptive_rkv_diversity_block_set_indices_begins = std::make_shared(element::i32, PartialShape{5}); + const auto token_type_ids = std::make_shared(ov::element::i32, ov::Shape{0}); + ov::OutputVector args = {query, key, value, @@ -133,7 +138,8 @@ TEST(type_prop, paged_attention_static_eviction_per_token) { adaptive_rkv_start_size, adaptive_rkv_evictable_sizes, adaptive_rkv_diversity_block_set_indices, - adaptive_rkv_diversity_block_set_indices_begins}; + adaptive_rkv_diversity_block_set_indices_begins, + token_type_ids}; const auto op = std::make_shared(args); EXPECT_EQ(op->get_output_element_type(0), element::f32); @@ -174,6 +180,8 @@ TEST(type_prop, paged_attention_dynamic_ranks_and_types) { const auto adaptive_rkv_diversity_block_set_indices_begins = std::make_shared(element::dynamic, dyn); + const auto token_type_ids = std::make_shared(ov::element::i32, ov::Shape{0}); + ov::OutputVector args = {query, key, value, @@ -198,7 +206,8 @@ 
TEST(type_prop, paged_attention_dynamic_ranks_and_types) { adaptive_rkv_start_size, adaptive_rkv_evictable_sizes, adaptive_rkv_diversity_block_set_indices, - adaptive_rkv_diversity_block_set_indices_begins}; + adaptive_rkv_diversity_block_set_indices_begins, + token_type_ids}; EXPECT_NO_THROW(std::ignore = std::make_shared(args)); } diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp index a89dd072965dfd..6a0ecf74394f33 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -460,6 +460,54 @@ struct MHAHelper { size_t _sparse_mask_block_size = 0; bool _use_softmax_sparse_mask = false; + // Bidirectional attention for image token groups (e.g. Gemma3 VLM) + bool _has_image_tokens = false; // true only when token_type_ids input is provided + PlainTensor _token_type; // [total_batched_tokens], int32 — 0=text, 1=image + std::vector _image_group_end; // for image token i, the exclusive end of its group + + // Precompute image group boundaries from token_type_ids. + // For each image token, _image_group_end[i] = index past the last contiguous image token in the same group. + // For text tokens, _image_group_end[i] = -1. 
+    //
+    // NOTE: group ends are stored in per-sequence KV coordinates, i.e. (past_len + offset of the token
+    // inside its subsequence). This is the coordinate system of the ncausal / cur_kv_len values they are
+    // later compared with. A raw global token index would only be correct for the first subsequence of a
+    // batch with past_len == 0.
+    //
+    // token_type:         [total_batched_tokens], i32, 0 = text, 1 = image
+    // subsequence_begins: [B_seq + 1], i32, global index of the first token of every subsequence
+    // past_lens:          [B_seq], i32, tokens already present in the KV cache for every subsequence
+    void set_token_type(const PlainTensor& token_type,
+                        const PlainTensor& subsequence_begins,
+                        const PlainTensor& past_lens) {
+        _has_image_tokens = true;
+        _token_type = token_type;
+
+        // assign() instead of resize(): tokens that are not covered by any subsequence must not keep
+        // group ends computed by a previous call.
+        const auto total_tokens = static_cast<size_t>(token_type.m_dims[0]);
+        _image_group_end.assign(total_tokens, -1);
+
+        const auto* types = token_type.ptr<int32_t>();
+        const auto* seq_begins = subsequence_begins.ptr<int32_t>();
+        const auto* past = past_lens.ptr<int32_t>();
+
+        const auto seq_count = static_cast<int32_t>(past_lens.m_dims[0]);
+        for (int32_t seq = 0; seq < seq_count; seq++) {
+            const int32_t seq_begin = seq_begins[seq];
+            const int32_t seq_end = seq_begins[seq + 1];
+            // Global token i of this subsequence sits at KV position (i + kv_offset).
+            const int32_t kv_offset = past[seq] - seq_begin;
+
+            // Backward scan within this subsequence: the last image token of a group defines the end,
+            // every earlier image token of the same contiguous group inherits it.
+            for (int32_t i = seq_end - 1; i >= seq_begin; i--) {
+                if (types[i] != 1) {
+                    continue;  // text token: keeps the -1 sentinel
+                }
+                const bool next_is_image = (i + 1 < seq_end) && (types[i + 1] == 1);
+                _image_group_end[i] = next_is_image ? _image_group_end[i + 1] : (i + 1 + kv_offset);
+            }
+        }
+    }
+
+    // Return the adjusted ncausal that extends to the image group end for image tokens.
+    // For text tokens, returns the original ncausal unchanged.
+    //
+    // q_global_idx:    global (batch-wide) index of the query token
+    // default_ncausal: number of KV positions visible under plain causal masking
+    // cur_kv_len:      upper bound for the returned value
+    //
+    // The result is never smaller than default_ncausal: bidirectional attention inside an image group
+    // may only widen the causal window, it must not hide KV entries a causal query would see.
+    size_t get_ncausal(size_t q_global_idx, size_t default_ncausal, size_t cur_kv_len) const {
+        if (!_has_image_tokens || q_global_idx >= _image_group_end.size()) {
+            return default_ncausal;
+        }
+        if (_token_type.ptr<int32_t>()[q_global_idx] != 1) {
+            return default_ncausal;  // text token
+        }
+        const int32_t group_end = _image_group_end[q_global_idx];
+        if (group_end < 0) {
+            // no group end recorded for this token; do not cast the sentinel to size_t
+            return default_ncausal;
+        }
+        // extend ncausal to the end of the image group, capped by cur_kv_len
+        const size_t extended = std::min(static_cast<size_t>(group_end), cur_kv_len);
+        return extended > default_ncausal ? extended : default_ncausal;
+    }
CpuParallelPtr _cpu_parallel; MHAHelper() { @@ -698,6 +746,8 @@ struct MHAHelper { // output_emb: [L, H * S] // qk_scratch_b: [rnd_up(kv_len, block_size), Hk, scratch_b_size] // wv_scratch_b: [rnd_up(kv_len, block_size), Hk, scratch_b_size] + // q_token_start: global token index of the first query token in this subsequence, used to + // look up per-token data when computing ncausal void exec_kernel_multiple(const PlainTensor& query, const PlainTensor& present_value, const PlainTensor& output_emb, @@ -717,7 +767,8 @@ struct MHAHelper { const ScoreAggregationInfo* score_info_ptr, const PlainTensor& sinks, size_t batch_in_seq = 0, - const std::vector& sparse_attention_mask = {}) { + const std::vector& sparse_attention_mask = {}, + size_t q_token_start = 0) { auto q_start = q_blk * _block_size; auto q_end = std::min(q_start + _block_size, q_len); auto q_cnt = q_end - q_start; @@ -796,7 +847,7 @@ struct MHAHelper { for (size_t m = q_start; m < q_end; m++) { // apply attention mask & sofmax - auto ncausal = (cur_kv_len - q_cnt + (m - q_start) + 1); + auto ncausal = get_ncausal(q_token_start + m, cur_kv_len - q_cnt + (m - q_start) + 1, cur_kv_len); auto* score = _weight.ptr(ithr, h - hq_beg, m - q_start); // dequantization of q matrix could be fused with _d_scale since softmax is done by row float revised_d_scale = @@ -965,7 +1016,8 @@ struct MHAHelper { const PlainTensor& alibi_slopes, float* score_output, size_t q_start_idx_score, - const ScoreAggregationInfo* score_info_ptr) { + const ScoreAggregationInfo* score_info_ptr, + size_t q_token_start = 0) { 
auto q_start = q_blk * _block_size; auto q_end = std::min(q_start + _block_size, q_len); auto q_cnt = q_end - q_start; @@ -1004,7 +1056,7 @@ struct MHAHelper { for (size_t m = q_start; m < q_end; m++) { // apply softmax in f32 precision - auto ncausal = (cur_kv_len - q_cnt + (m - q_start) + 1); + auto ncausal = get_ncausal(q_token_start + m, cur_kv_len - q_cnt + (m - q_start) + 1, cur_kv_len); auto soft_in = _weight.ptr(ithr, h - hq_beg, m - q_start); auto score = _weight.ptr(ithr, h - hq_beg, m - q_start); PlainTensor f32_cvt; @@ -1108,7 +1160,8 @@ struct MHAHelper { size_t cur_kv_len, const PlainTensor& alibi_slopes, float* score_output, - const PlainTensor& sinks) { + const PlainTensor& sinks, + size_t q_token_start = 0) { # if defined(OPENVINO_ARCH_X86_64) if (any_of(_fastpath_valid_prec, ov::element::bf16, ov::element::f16)) { _gemv->tile_config(); @@ -1159,13 +1212,14 @@ struct MHAHelper { for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { // apply attention mask & sofmax + auto ncausal = get_ncausal(q_token_start + pq, cur_kv_len, cur_kv_len); float* score = _weight.ptr(ithr, h - hq_beg, pq); OPENVINO_DEBUG_ASSERT(score != nullptr, "PagedAttention: _weight buffer must be allocated"); float* alibi_lookup = nullptr; float alibi_slope = 0.F; if (alibi_slopes) { alibi_slope = alibi_slopes.ptr()[h]; - alibi_lookup = _alibi_lookup.ptr() + _alibi_lookup.m_dims[0] - cur_kv_len; + alibi_lookup = _alibi_lookup.ptr() + _alibi_lookup.m_dims[0] - ncausal; } float* sink = nullptr; if (sinks) { @@ -1173,10 +1227,10 @@ struct MHAHelper { } if (_sliding_window) { size_t start_idx = 0; - size_t new_causal = cur_kv_len; + size_t new_causal = ncausal; float* sw_alibi_lookup = nullptr; - if (cur_kv_len > _sliding_window) { - start_idx = cur_kv_len - _sliding_window; + if (ncausal > _sliding_window) { + start_idx = ncausal - _sliding_window; new_causal = _sliding_window; } attn_softmax_kernel(score + start_idx, @@ -1203,7 +1257,7 @@ struct 
MHAHelper { nullptr, nullptr, false, - cur_kv_len, + ncausal, cur_kv_len, ov::element::f32, ov::element::f32, @@ -1368,15 +1422,16 @@ struct MHAHelper { auto loop_softmax = [&](size_t b, size_t h, size_t pq) { auto cur_kv_len = static_cast(past_lens.ptr()[b]) + 1; - auto ncausal = cur_kv_len; - // apply attention mask & sofmax + auto q_token_start = static_cast(subsequence_begins.ptr()[b]); + auto ncausal = get_ncausal(q_token_start + pq, cur_kv_len, cur_kv_len); + // apply attention mask & sofmax float* score = _weight_bhl.ptr(b, h, pq); OPENVINO_DEBUG_ASSERT(score != nullptr, "PagedAttention: _weight_bhl buffer must be allocated"); float* alibi_lookup = nullptr; float alibi_slope = 0.F; if (alibi_slopes) { alibi_slope = alibi_slopes.ptr()[h]; - alibi_lookup = _alibi_lookup.ptr() + _alibi_lookup.m_dims[0] - cur_kv_len; + alibi_lookup = _alibi_lookup.ptr() + _alibi_lookup.m_dims[0] - ncausal; } float* sink = nullptr; if (sinks) { @@ -1690,7 +1745,8 @@ struct MHA { cur_kv_len, alibi_slopes, score_output, - sinks); + sinks, + static_cast(batch_in_token)); } else { const auto batch_in_reorder = item.batch_in_reorder; const auto q_blk = item.q_block_id; @@ -1742,7 +1798,8 @@ struct MHA { alibi_slopes, score_output, q_start_idx_score, - score_info_ptr); + score_info_ptr, + static_cast(batch_in_token)); } else { _helper.exec_kernel_multiple( sub_query, @@ -1763,7 +1820,10 @@ struct MHA { score_output, q_start_idx_score, score_info_ptr, - PlainTensor()); + PlainTensor(), + 0, + {}, + static_cast(batch_in_token)); } # else _helper.exec_kernel_multiple( @@ -1787,7 +1847,8 @@ struct MHA { score_info_ptr, sinks, batch_in_seq, - sparse_attention_mask); + sparse_attention_mask, + static_cast(batch_in_token)); # endif } }); @@ -1920,7 +1981,8 @@ struct AttentionExecutor : public PagedAttentionExecutor { PlainTensor& output_emb, PlainTensor& output_score, std::vector& sparse_attention_mask, - PlainTensor& output_arkv_similarity) { + PlainTensor& output_arkv_similarity, + 
PlainTensor& token_type_ids) { q.reset(inputs[ID_Q]); // [B_token, H * S] k.reset(inputs[ID_K]); v.reset(inputs[ID_V]); @@ -1942,7 +2004,7 @@ struct AttentionExecutor : public PagedAttentionExecutor { } size_t inputs_size = inputs.size(); - OPENVINO_ASSERT(inputs_size == 25); + OPENVINO_ASSERT(inputs_size == 26); if (!inputs[ID_ROTATED_BLOCK_INDICES]->getShape().hasZeroDims()) { rotated_block_indices.reset(inputs[ID_ROTATED_BLOCK_INDICES]); // [num_blocks] } @@ -1980,6 +2042,14 @@ struct AttentionExecutor : public PagedAttentionExecutor { output_arkv_similarity.reset(outputs[2]); } + if (!inputs[ID_TOKEN_TYPE_IDS]->getShape().hasZeroDims()) { + token_type_ids.reset(inputs[ID_TOKEN_TYPE_IDS]); + if (token_type_ids.m_rank == 2) { + auto total = token_type_ids.m_dims[0] * token_type_ids.m_dims[1]; + token_type_ids = token_type_ids.reshape({total}); + } + } + output_emb.reset(outputs[0]); if (outputs.size() >= 2) { output_score.reset(outputs[1]); @@ -2132,6 +2202,10 @@ struct AttentionExecutor : public PagedAttentionExecutor { } } + if (token_type_ids) { + token_type_ids.assert_dims({B_token}); + } + output_emb.assert_dims({B_token, H * SV}); output_emb = output_emb.reshape({B_token, 1, H * SV}); @@ -2270,6 +2344,8 @@ struct AttentionExecutor : public PagedAttentionExecutor { PlainTensor output_score; PlainTensor output_arkv_similarity; + PlainTensor token_type_ids; + std::vector sparse_attention_mask; // Each vector element corresponds to a batch, and each PlainTensor corresponds to a // batch, with shape: [H, q_blocks, k_blocks], type: bool @@ -2304,7 +2380,15 @@ struct AttentionExecutor : public PagedAttentionExecutor { output_emb, output_score, sparse_attention_mask, - output_arkv_similarity); + output_arkv_similarity, + token_type_ids); + + if (token_type_ids) { + _helper.set_token_type(token_type_ids, subsequence_begins, past_lens); + } else { + _helper._token_type = PlainTensor(); + _helper._image_group_end.clear(); + } if (rotated_block_indices) { // Rotate kv 
cache currently doesn't support quantized cache. diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.hpp index 0a9d2517fd515d..e68b959b14c949 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.hpp @@ -46,6 +46,7 @@ struct PagedAttentionExecutor { static const size_t ID_ADAPTIVE_RKV_EVICTABLE_SIZES = 22; // [B_seq], int32 static const size_t ID_ADAPTIVE_RKV_DIVERSITY_BLOCK_SET_INDICES = 23; // [num_adaptive_rkv_blocks], int32 static const size_t ID_ADAPTIVE_RKV_DIVERSITY_BLOCK_SET_INDICES_BEGINS = 24; // [B_seq + 1], int32 + static const size_t ID_TOKEN_TYPE_IDS = 25; // [B_token | 0] or [1, B_token], i32 virtual void execute(const std::vector& inputs, std::vector outputs) = 0; virtual ~PagedAttentionExecutor() = default; diff --git a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp index 5079e7d136e6a6..468398cbea81ba 100644 --- a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp @@ -99,7 +99,7 @@ void PagedAttention::initSupportedPrimitiveDescriptors() { creatorsMap.at(LayoutType::ncsp) ->createSharedDesc(rtPrecision, getInputShapeAtPort(PagedAttentionExecutor::ID_V))); - CPU_NODE_ASSERT(orgInputNumber == 25U, "The input number of PagedAttention should be 25."); + CPU_NODE_ASSERT(orgInputNumber == 26U, "The input number of PagedAttention should be 26."); // kvcache, float, [] auto past_key_input_mem_precision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_KCACHE); auto past_value_input_mem_precision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_VCACHE); @@ -207,6 +207,11 @@ void PagedAttention::initSupportedPrimitiveDescriptors() { ov::element::i32, 
getInputShapeAtPort(PagedAttentionExecutor::ID_ADAPTIVE_RKV_DIVERSITY_BLOCK_SET_INDICES_BEGINS))); + // token_type_ids, i32, [B_token | 0] or [1, B_token] + config.inConfs[PagedAttentionExecutor::ID_TOKEN_TYPE_IDS].setMemDesc( + creatorsMap.at(LayoutType::ncsp) + ->createSharedDesc(ov::element::i32, getInputShapeAtPort(PagedAttentionExecutor::ID_TOKEN_TYPE_IDS))); + config.outConfs[2].setMemDesc( creatorsMap.at(LayoutType::ncsp)->createSharedDesc(ov::element::f32, getOutputShapeAtPort(2))); diff --git a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/paged_attn.cpp b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/paged_attn.cpp index 5df3da81853e64..b6f7cd6df7362f 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/paged_attn.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/paged_attn.cpp @@ -159,6 +159,7 @@ class PagedAttnTestBase : public testing::WithParamInterface(ov::element::i32, Shape{0}, std::vector{0}); auto adaptive_rkv_diversity_block_set_indices_begins = std::make_shared(ov::element::i32, Shape{0}, std::vector{0}); + auto token_type_ids = std::make_shared(ov::element::i32, ov::Shape{0}, std::vector{}); ParameterVector params = {q, k, v, key_cache, value_cache, past_lens, subsequence_begins, block_indices, block_indices_begins}; OutputVector paged_attn_inputs = {q, @@ -185,7 +186,8 @@ class PagedAttnTestBase : public testing::WithParamInterface(paged_attn_inputs); diff --git a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/paged_attn_score.cpp b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/paged_attn_score.cpp index bc2560c136b73c..8bbb1776e786ba 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/paged_attn_score.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/paged_attn_score.cpp @@ -138,6 +138,7 @@ class PagedAttnScoreTest : public 
testing::WithParamInterface(ov::element::i32, Shape{0}, std::vector{0}); auto adaptive_rkv_diversity_block_set_indices_begins = std::make_shared(ov::element::i32, Shape{0}, std::vector{0}); + auto token_type_ids = std::make_shared(ov::element::i32, ov::Shape{0}, std::vector{}); ParameterVector params = {q, k, v, key_cache, value_cache, past_lens, subsequence_begins, block_indices, block_indices_begins}; auto paged_attn = std::make_shared(OutputVector{q, @@ -164,7 +165,8 @@ class PagedAttnScoreTest : public testing::WithParamInterfaceget_rt_info()["num_k_heads"] = head_num; paged_attn->get_rt_info()["k_head_size"] = head_size; paged_attn->get_rt_info()["num_v_heads"] = head_num; diff --git a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/paged_attn_token_type.cpp b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/paged_attn_token_type.cpp new file mode 100644 index 00000000000000..e8f1eb52680eca --- /dev/null +++ b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/paged_attn_token_type.cpp @@ -0,0 +1,433 @@ +// Copyright (C) 2018-2026 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + + +#include "common_test_utils/include/common_test_utils/ov_tensor_utils.hpp" +#include "common_test_utils/node_builders/constant.hpp" +#include "internal_properties.hpp" +#include "openvino/core/type/float16.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/broadcast.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/convert_like.hpp" +#include "openvino/op/divide.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/greater.hpp" +#include "openvino/op/matmul.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/paged_attention.hpp" +#include "openvino/op/parameter.hpp" +#include "openvino/op/range.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/scaled_dot_product_attention.hpp" +#include "openvino/op/select.hpp" 
+#include "openvino/op/shape_of.hpp" +#include "openvino/op/softmax.hpp" +#include "openvino/op/sqrt.hpp" +#include "openvino/op/squeeze.hpp" +#include "openvino/op/transpose.hpp" +#include "openvino/op/unsqueeze.hpp" +#include "shared_test_classes/base/ov_subgraph.hpp" +#include "utils/cpu_test_utils.hpp" +#include "utils/general_utils.h" + +using namespace ov::test; +using namespace CPUTestUtils; +using namespace ov::op; + +namespace ov { +namespace test { + +struct TokenTypePattern { + std::string name; + std::vector types; // 0=text, 1=image +}; + +using PagedAttnTokenTypeParams = std::tuple; + +class PagedAttnTokenTypeTest : public testing::WithParamInterface, + virtual public ov::test::SubgraphBaseTest, + public CPUTestsBase { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + const auto& [inType, head_size, head_num, pattern] = obj.param; + std::ostringstream result; + result << "Prc=" << inType << "_"; + result << "HS=" << head_size << "_"; + result << "HN=" << head_num << "_"; + result << "Pattern=" << pattern.name; + return result.str(); + } + + static std::shared_ptr make_param(const PartialShape& pshape, + element::Type element_type, + const std::string& name) { + auto param = std::make_shared(element_type, pshape); + param->set_friendly_name(name); + param->get_output_tensor(0).set_names({name}); + return param; + } + + std::shared_ptr get_pa_model(ov::element::Type data_type, + ov::Dimension::value_type head_size, + ov::Dimension::value_type head_num) { + auto q = make_param(PartialShape{ov::Dimension::dynamic(), ov::Dimension::dynamic()}, data_type, "q"); + auto k = make_param(PartialShape{ov::Dimension::dynamic(), head_num * head_size}, data_type, "k"); + auto v = make_param(PartialShape{ov::Dimension::dynamic(), head_num * head_size}, data_type, "v"); + auto key_cache = make_param(PartialShape{ov::Dimension::dynamic(), 32, ov::Dimension::dynamic()}, + ov::element::dynamic, "key_cache.0"); + auto value_cache = 
make_param(PartialShape{ov::Dimension::dynamic(), 32, ov::Dimension::dynamic()}, + ov::element::dynamic, "value_cache.0"); + auto past_lens = make_param(PartialShape{ov::Dimension::dynamic()}, ov::element::i32, "past_lens"); + auto subsequence_begins = make_param(PartialShape{ov::Dimension::dynamic()}, ov::element::i32, "subsequence_begins"); + auto block_indices = make_param(PartialShape{ov::Dimension::dynamic()}, ov::element::i32, "block_indices"); + auto block_indices_begins = make_param(PartialShape{ov::Dimension::dynamic()}, ov::element::i32, "block_indices_begins"); + + float scale_value = 1.0f / std::sqrt(static_cast(head_size)); + auto scale = std::make_shared(ov::element::f32, ov::Shape{}, std::vector{scale_value}); + auto sliding_window = std::make_shared(ov::element::i32, Shape{}, std::vector{0}); + auto alibi_slopes = std::make_shared(ov::element::f32, Shape{0}, std::vector{}); + auto max_context_len = std::make_shared(ov::element::i32, Shape{}, std::vector{1024}); + auto score_aggregation_window = std::make_shared(ov::element::i32, Shape{}, std::vector{0}); + auto rotated_block_indices = std::make_shared(ov::element::i32, Shape{0}, std::vector{0}); + auto rotation_deltas = std::make_shared(ov::element::i32, Shape{0}, std::vector{0}); + auto rotation_trig_lut = std::make_shared(ov::element::f32, Shape{0}, std::vector{0}); + auto xattention_threshold = std::make_shared(ov::element::f32, Shape{0}, std::vector{0}); + auto xattention_block_size = std::make_shared(ov::element::i32, Shape{}, std::vector{64}); + auto xattention_stride = std::make_shared(ov::element::i32, Shape{}, std::vector{8}); + auto sinks = std::static_pointer_cast( + ov::test::utils::make_constant(data_type, Shape{0})); + auto adaptive_rkv_start_size = std::make_shared(ov::element::i32, Shape{}, std::vector{0}); + auto adaptive_rkv_evictable_sizes = std::make_shared(ov::element::i32, Shape{0}, std::vector{0}); + auto adaptive_rkv_diversity_block_set_indices = 
std::make_shared(ov::element::i32, Shape{0}, std::vector{0}); + auto adaptive_rkv_diversity_block_set_indices_begins = std::make_shared(ov::element::i32, Shape{0}, std::vector{0}); + + auto token_type_ids = make_param(PartialShape{ov::Dimension::dynamic()}, ov::element::i32, "token_type_ids"); + + ParameterVector params = {q, k, v, key_cache, value_cache, past_lens, + subsequence_begins, block_indices, block_indices_begins, + token_type_ids}; + + OutputVector pa_inputs = {q, k, v, key_cache, value_cache, + past_lens, subsequence_begins, block_indices, block_indices_begins, + scale, sliding_window, alibi_slopes, max_context_len, + score_aggregation_window, rotated_block_indices, rotation_deltas, + rotation_trig_lut, xattention_threshold, xattention_block_size, + xattention_stride, sinks, adaptive_rkv_start_size, + adaptive_rkv_evictable_sizes, + adaptive_rkv_diversity_block_set_indices, + adaptive_rkv_diversity_block_set_indices_begins, + token_type_ids}; + + OPENVINO_ASSERT(pa_inputs.size() == 26); + + auto paged_attn = std::make_shared(pa_inputs); + paged_attn->get_rt_info()["num_k_heads"] = head_num; + paged_attn->get_rt_info()["k_head_size"] = head_size; + paged_attn->get_rt_info()["num_v_heads"] = head_num; + paged_attn->get_rt_info()["v_head_size"] = head_size; + + return std::make_shared(OutputVector{paged_attn}, params); + } + + template + static void strided_iota(IT first, size_t n, T value, T stride) { + for (size_t i = 0; i < n; i++) { + const float idx = static_cast(n - 1 - i); + *first++ = value + stride * static_cast(idx); + } + } + + struct RunResult { + ov::Tensor output; + }; + + RunResult run_pa_with_token_types(std::shared_ptr model, + ov::element::Type data_type, + size_t seq_len, + size_t head_size, + size_t head_num, + const std::vector& token_types) { + OPENVINO_ASSERT(token_types.size() == seq_len); + + configuration[ov::hint::inference_precision.name()] = ov::element::f32; + function = model; + compile_model(); + auto infer_request = 
compiledModel.create_infer_request(); + + // Determine cache precision from compiled model + ov::Tensor key_cache_tensor, value_cache_tensor; + for (const auto& input : compiledModel.inputs()) { + for (auto& name : input.get_names()) { + auto cache_precision = input.get_element_type(); + const size_t block_nums = 1024 / 32; + ov::PartialShape pshape; + if (name.find("key_cache.") == 0) { + pshape = input.get_partial_shape(); + pshape[0] = block_nums; + key_cache_tensor = ov::Tensor(cache_precision, pshape.get_shape()); + } else if (name.find("value_cache.") == 0) { + pshape = input.get_partial_shape(); + pshape[0] = block_nums; + value_cache_tensor = ov::Tensor(cache_precision, pshape.get_shape()); + } + } + } + + auto params = model->get_parameters(); + size_t hidden_dim = head_num * head_size; + + // q, k, v tensors [seq_len, hidden_dim] + auto fill_tensor = [](ov::Tensor& t, float base, float stride) { + auto* p = t.data(); + for (size_t i = 0; i < t.get_size(); i++) { + p[i] = base + stride * static_cast(i % 17); // pseudo-random repeating pattern + } + }; + + ov::Tensor q_tensor(data_type, {seq_len, hidden_dim}); + ov::Tensor k_tensor(data_type, {seq_len, hidden_dim}); + ov::Tensor v_tensor(data_type, {seq_len, hidden_dim}); + + if (data_type == ov::element::f32) { + fill_tensor(q_tensor, 0.1f, 0.01f); + fill_tensor(k_tensor, 0.2f, 0.01f); + fill_tensor(v_tensor, 0.3f, 0.01f); + } + + // Prefill: past_lens=0, single sequence + size_t batch_size = 1; + int32_t total_blocks = static_cast((seq_len + 31) / 32); + + ov::Tensor past_lens(ov::element::i32, {batch_size}); + ov::Tensor subsequence_begins(ov::element::i32, {batch_size + 1}); + ov::Tensor block_indices(ov::element::i32, {static_cast(total_blocks)}); + ov::Tensor block_indices_begins(ov::element::i32, {batch_size + 1}); + + past_lens.data()[0] = 0; + subsequence_begins.data()[0] = 0; + subsequence_begins.data()[1] = static_cast(seq_len); + block_indices_begins.data()[0] = 0; + 
block_indices_begins.data()[1] = total_blocks; + for (int32_t i = 0; i < total_blocks; i++) { + block_indices.data()[i] = i; + } + + // token_type_ids + ov::Tensor token_type_tensor(ov::element::i32, {seq_len}); + std::memcpy(token_type_tensor.data(), token_types.data(), seq_len * sizeof(int32_t)); + + for (auto& param : params) { + auto name = param->get_friendly_name(); + if (name == "q") infer_request.set_tensor(param, q_tensor); + else if (name == "k") infer_request.set_tensor(param, k_tensor); + else if (name == "v") infer_request.set_tensor(param, v_tensor); + else if (name == "key_cache.0") infer_request.set_tensor(param, key_cache_tensor); + else if (name == "value_cache.0") infer_request.set_tensor(param, value_cache_tensor); + else if (name == "past_lens") infer_request.set_tensor(param, past_lens); + else if (name == "subsequence_begins") infer_request.set_tensor(param, subsequence_begins); + else if (name == "block_indices") infer_request.set_tensor(param, block_indices); + else if (name == "block_indices_begins") infer_request.set_tensor(param, block_indices_begins); + else if (name == "token_type_ids") infer_request.set_tensor(param, token_type_tensor); + } + + infer_request.infer(); + + auto output = infer_request.get_output_tensor(0); + ov::Tensor output_copy{output.get_element_type(), output.get_shape()}; + output.copy_to(output_copy); + return {output_copy}; + } +}; + +// With all-zero token_type_ids (text-only), causal masking must hold: +// output at position i depends only on tokens 0..i, so a shorter prefix +// should produce identical outputs for the shared positions. +TEST_P(PagedAttnTokenTypeTest, AllTextIsCausal) { + SKIP_IF_CURRENT_TEST_IS_DISABLED(); + const auto& [inType, head_size, head_num, pattern] = this->GetParam(); + if (inType == ElementType::bf16 && !ov::with_cpu_x86_bfloat16()) + GTEST_SKIP(); + + targetDevice = ov::test::utils::DEVICE_CPU; + + size_t full_len = pattern.types.size(); + size_t prefix_len = full_len / 2; // e.g. 
5 out of 10 + OPENVINO_ASSERT(prefix_len > 0); + + std::vector all_text_full(full_len, 0); + std::vector all_text_prefix(prefix_len, 0); + + auto model_full = get_pa_model(inType, head_size, head_num); + auto result_full = run_pa_with_token_types(model_full, inType, full_len, head_size, head_num, all_text_full); + + auto model_prefix = get_pa_model(inType, head_size, head_num); + auto result_prefix = run_pa_with_token_types(model_prefix, inType, prefix_len, head_size, head_num, all_text_prefix); + + // First prefix_len positions of the full run should match the prefix run exactly + size_t hidden_dim = head_num * head_size; + auto* full_data = result_full.output.data(); + auto* prefix_data = result_prefix.output.data(); + + for (size_t pos = 0; pos < prefix_len; pos++) { + for (size_t d = 0; d < hidden_dim; d++) { + float diff = std::abs(full_data[pos * hidden_dim + d] - prefix_data[pos * hidden_dim + d]); + EXPECT_LT(diff, 1e-5f) + << "Causal masking violated: position " << pos << " dim " << d + << " differs between seq_len=" << full_len << " and seq_len=" << prefix_len + << ", diff=" << diff; + } + } +} + + +TEST_P(PagedAttnTokenTypeTest, ImageTokensDifferFromCausal) { + SKIP_IF_CURRENT_TEST_IS_DISABLED(); + const auto& [inType, head_size, head_num, pattern] = this->GetParam(); + if (inType == ElementType::bf16 && !ov::with_cpu_x86_bfloat16()) + GTEST_SKIP(); + + targetDevice = ov::test::utils::DEVICE_CPU; + + size_t seq_len = pattern.types.size(); + + auto pa_model = get_pa_model(inType, head_size, head_num); + auto result_bidir = run_pa_with_token_types(pa_model, inType, seq_len, head_size, head_num, pattern.types); + + std::vector all_causal(seq_len, 0); + auto pa_model_causal = get_pa_model(inType, head_size, head_num); + auto result_causal = run_pa_with_token_types(pa_model_causal, inType, seq_len, head_size, head_num, all_causal); + + // Image tokens should have different output compared to causal-only + size_t hidden_dim = head_num * head_size; + auto* 
bidir_data = result_bidir.output.data(); + auto* causal_data = result_causal.output.data(); + + bool any_image_differs = false; + for (size_t pos = 0; pos < seq_len; pos++) { + if (pattern.types[pos] != 1) continue; // Skip text tokens + // Only image tokens that are NOT the last in their group will differ, + // because only they gain access to future KV positions. + for (size_t d = 0; d < hidden_dim; d++) { + float diff = std::abs(bidir_data[pos * hidden_dim + d] - causal_data[pos * hidden_dim + d]); + if (diff > 1e-5f) { + any_image_differs = true; + break; + } + } + if (any_image_differs) break; + } + EXPECT_TRUE(any_image_differs) + << "Expected image tokens with bidirectional attention to produce different output than causal-only"; +} + +// Text tokens outside image groups should be unaffected by token_type_ids +TEST_P(PagedAttnTokenTypeTest, TextTokensUnaffected) { + SKIP_IF_CURRENT_TEST_IS_DISABLED(); + const auto& [inType, head_size, head_num, pattern] = this->GetParam(); + if (inType == ElementType::bf16 && !ov::with_cpu_x86_bfloat16()) + GTEST_SKIP(); + + targetDevice = ov::test::utils::DEVICE_CPU; + + size_t seq_len = pattern.types.size(); + + auto pa_model = get_pa_model(inType, head_size, head_num); + auto result_bidir = run_pa_with_token_types(pa_model, inType, seq_len, head_size, head_num, pattern.types); + + std::vector all_causal(seq_len, 0); + auto pa_model_causal = get_pa_model(inType, head_size, head_num); + auto result_causal = run_pa_with_token_types(pa_model_causal, inType, seq_len, head_size, head_num, all_causal); + + // Text tokens BEFORE the first image group should have identical output + size_t hidden_dim = head_num * head_size; + auto* bidir_data = result_bidir.output.data(); + auto* causal_data = result_causal.output.data(); + + size_t first_image_pos = seq_len; + for (size_t i = 0; i < seq_len; i++) { + if (pattern.types[i] == 1) { first_image_pos = i; break; } + } + + for (size_t pos = 0; pos < first_image_pos; pos++) { + for 
(size_t d = 0; d < hidden_dim; d++) { + float diff = std::abs(bidir_data[pos * hidden_dim + d] - causal_data[pos * hidden_dim + d]); + EXPECT_LT(diff, 1e-5f) + << "Text token at position " << pos << " dim " << d + << " should be unaffected by token_type_ids, but diff=" << diff; + } + } +} + + +// Text tokens after the last image group should match causal baseline +TEST_P(PagedAttnTokenTypeTest, PostImageTextIsCausal) { + SKIP_IF_CURRENT_TEST_IS_DISABLED(); + const auto& [inType, head_size, head_num, pattern] = this->GetParam(); + if (inType == ElementType::bf16 && !ov::with_cpu_x86_bfloat16()) + GTEST_SKIP(); + + targetDevice = ov::test::utils::DEVICE_CPU; + + size_t seq_len = pattern.types.size(); + + auto pa_model = get_pa_model(inType, head_size, head_num); + auto result_bidir = run_pa_with_token_types(pa_model, inType, seq_len, head_size, head_num, pattern.types); + + std::vector all_causal(seq_len, 0); + auto pa_model_causal = get_pa_model(inType, head_size, head_num); + auto result_causal = run_pa_with_token_types(pa_model_causal, inType, seq_len, head_size, head_num, all_causal); + + size_t hidden_dim = head_num * head_size; + auto* bidir_data = result_bidir.output.data(); + auto* causal_data = result_causal.output.data(); + + size_t last_image_pos = 0; + for (size_t i = 0; i < seq_len; i++) { + if (pattern.types[i] == 1) last_image_pos = i; + } + + // Text tokens after the last image group should match causal baseline + for (size_t pos = last_image_pos + 1; pos < seq_len; pos++) { + ASSERT_EQ(pattern.types[pos], 0) << "Expected text token at position " << pos; + for (size_t d = 0; d < hidden_dim; d++) { + float diff = std::abs(bidir_data[pos * hidden_dim + d] - causal_data[pos * hidden_dim + d]); + EXPECT_LT(diff, 1e-5f) + << "Post-image text token at position " << pos << " dim " << d + << " should match causal baseline, but diff=" << diff; + } + } +} + +namespace { + +const std::vector token_type_patterns = { + // Symmetric: text + centered image group + 
text + {"centered_image", {0, 0, 0, 1, 1, 1, 1, 0, 0, 0}}, + + // Image group near the start, more trailing text + {"early_image", {0, 1, 1, 1, 1, 0, 0, 0, 0, 0}}, + + // Image group near the end, more leading text + {"late_image", {0, 0, 0, 0, 0, 1, 1, 1, 1, 0}}, + + // Large image group — almost all image, minimal text framing + {"large_image", {0, 1, 1, 1, 1, 1, 1, 1, 1, 0}}, + + // Two separate image groups with text between and after + {"two_image_groups", {0, 1, 1, 0, 0, 0, 1, 1, 1, 0}}, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_PagedAttnTokenType, + PagedAttnTokenTypeTest, + ::testing::Combine(::testing::Values(ElementType::f32), + ::testing::Values(64), // head_size + ::testing::Values(8), // head_num + ::testing::ValuesIn(token_type_patterns)), + PagedAttnTokenTypeTest::getTestCaseName); + +} // namespace +} // namespace test +} // namespace ov diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp index ba320e21af79b5..c9ad9eb20cc134 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp @@ -39,6 +39,7 @@ struct paged_attention : public primitive_base { ADAPTIVE_RKV_EVICTABLE_SIZES = 22, ADAPTIVE_RKV_DIVERSITY_BLOCK_SET_INDICES = 23, ADAPTIVE_RKV_DIVERSITY_BLOCK_SET_INDICES_BEGINS = 24, + TOKEN_TYPE_IDS = 25, }; static constexpr size_t block_size = 16; @@ -49,7 +50,7 @@ struct paged_attention : public primitive_base { paged_attention(const primitive_id& id, const std::vector& inputs) : primitive_base(id, inputs) { - OPENVINO_ASSERT((inputs.size() == 25), + OPENVINO_ASSERT((inputs.size() == 26), "[GPU] Unexpected inputs number for PagedAttention primitive: ", inputs.size()); } diff --git a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp index f894c55265fef9..73362c4b365ebf 100644 --- 
a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp @@ -21,7 +21,7 @@ using PagedAttentionExtension = ov::op::PagedAttentionExtension; namespace ov::intel_gpu { static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared_ptr& op) { - validate_inputs_count(op, {25}); + validate_inputs_count(op, {26}); auto inputs = p.GetInputInfo(op); auto prim = cldnn::paged_attention(layer_type_name_ID(op), inputs); diff --git a/src/plugins/intel_gpu/tests/unit/dynamic_execution/update_shape_test.cpp b/src/plugins/intel_gpu/tests/unit/dynamic_execution/update_shape_test.cpp index 21503a933b3d04..192e16a5712ddb 100644 --- a/src/plugins/intel_gpu/tests/unit/dynamic_execution/update_shape_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/dynamic_execution/update_shape_test.cpp @@ -175,6 +175,10 @@ TEST(update_shape_test, max_context_len_shapeof_subgraph) { auto adaptive_rkv_diversity_block_set_indices_begins_layout = layout{ov::PartialShape{1}, data_types::i32, format::bfyx}; auto adaptive_rkv_diversity_block_set_indices_begins_mem = engine.allocate_memory(adaptive_rkv_diversity_block_set_indices_begins_layout); + auto token_type_ids_layout = layout{ov::PartialShape{1}, data_types::i32, format::bfyx}; + auto token_type_ids_mem = engine.allocate_memory(token_type_ids_layout); + set_values(token_type_ids_mem, {0}); + std::vector pa_inputs = {input_info("query"), input_info("key"), input_info("value"), @@ -200,6 +204,7 @@ TEST(update_shape_test, max_context_len_shapeof_subgraph) { input_info("adaptive_rkv_evictable_sizes"), input_info("adaptive_rkv_diversity_block_set_indices"), input_info("adaptive_rkv_diversity_block_set_indices_begins"), + input_info("token_type_ids"), }; auto pa_prim = paged_attention("paged_attention", pa_inputs); @@ -240,6 +245,7 @@ TEST(update_shape_test, max_context_len_shapeof_subgraph) { topology.add(input_layout("adaptive_rkv_evictable_sizes", 
adaptive_rkv_evictable_sizes_layout)); topology.add(input_layout("adaptive_rkv_diversity_block_set_indices_begins", adaptive_rkv_diversity_block_set_indices_begins_layout)); topology.add(input_layout("adaptive_rkv_diversity_block_set_indices", adaptive_rkv_diversity_block_set_indices_layout)); + topology.add(input_layout("token_type_ids", token_type_ids_layout)); topology.add(data("const_one", const_one_mem)); topology.add(shape_of("shape_of", input_info("input_data"), data_types::i32)); topology.add(gather("gather", input_info("shape_of"), input_info("const_one"), 0, 1, ov::Shape{})); @@ -276,6 +282,7 @@ TEST(update_shape_test, max_context_len_shapeof_subgraph) { network.set_input_data("adaptive_rkv_evictable_sizes", adaptive_rkv_evictable_sizes_mem); network.set_input_data("adaptive_rkv_diversity_block_set_indices_begins", adaptive_rkv_diversity_block_set_indices_begins_mem); network.set_input_data("adaptive_rkv_diversity_block_set_indices", adaptive_rkv_diversity_block_set_indices_mem); + network.set_input_data("token_type_ids", token_type_ids_mem); // Set original max_context_len value auto max_context_len_mem_layout = layout{ov::PartialShape{1}, data_types::i32, format::bfyx}; @@ -319,6 +326,7 @@ TEST(update_shape_test, max_context_len_shapeof_subgraph) { network.set_input_data("adaptive_rkv_evictable_sizes", adaptive_rkv_evictable_sizes_mem); network.set_input_data("adaptive_rkv_diversity_block_set_indices_begins", adaptive_rkv_diversity_block_set_indices_begins_mem); network.set_input_data("adaptive_rkv_diversity_block_set_indices", adaptive_rkv_diversity_block_set_indices_mem); + network.set_input_data("token_type_ids", token_type_ids_mem); // Update max_context_len value, which should be taken into account in shape recalculation for broadcast set_values(max_context_len_mem, {8}); diff --git a/src/plugins/intel_gpu/tests/unit/passes/mark_shape_of_subgraphs_test.cpp b/src/plugins/intel_gpu/tests/unit/passes/mark_shape_of_subgraphs_test.cpp index 
c249c8e852f4b6..0a3054b08d9cdf 100644 --- a/src/plugins/intel_gpu/tests/unit/passes/mark_shape_of_subgraphs_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/passes/mark_shape_of_subgraphs_test.cpp @@ -471,6 +471,7 @@ TEST(mark_shape_of_subgraphs, paged_attention_max_context_len_input) { auto adaptive_rkv_evictable_sizes_layout = layout{ov::PartialShape{1}, data_types::f32, format::bfyx}; auto adaptive_rkv_diversity_block_set_indices_layout = layout{ov::PartialShape{1}, data_types::f32, format::bfyx}; auto adaptive_rkv_diversity_block_set_indices_begins_layout = layout{ov::PartialShape{1}, data_types::f32, format::bfyx}; + auto token_type_ids_layout = layout{ov::PartialShape{1}, data_types::i32, format::bfyx}; std::vector pa_inputs = {input_info("query"), input_info("key"), @@ -496,7 +497,8 @@ TEST(mark_shape_of_subgraphs, paged_attention_max_context_len_input) { input_info("adaptive_rkv_start_size"), input_info("adaptive_rkv_evictable_sizes"), input_info("adaptive_rkv_diversity_block_set_indices"), - input_info("adaptive_rkv_diversity_block_set_indices_begins") + input_info("adaptive_rkv_diversity_block_set_indices_begins"), + input_info("token_type_ids") }; auto pa_prim = paged_attention("paged_attention", pa_inputs); @@ -536,6 +538,7 @@ TEST(mark_shape_of_subgraphs, paged_attention_max_context_len_input) { topology.add(input_layout("adaptive_rkv_evictable_sizes", adaptive_rkv_evictable_sizes_layout)); topology.add(input_layout("adaptive_rkv_diversity_block_set_indices", adaptive_rkv_diversity_block_set_indices_layout)); topology.add(input_layout("adaptive_rkv_diversity_block_set_indices_begins", adaptive_rkv_diversity_block_set_indices_begins_layout)); + topology.add(input_layout("token_type_ids", token_type_ids_layout)); topology.add(input_layout("input", input_layout_dynamic)); topology.add(data("target_shape", target_shape)); topology.add(data("subtract_one", subtract_one)); diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp 
b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp index 589f2d3d84019a..66889a205415ea 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp @@ -588,6 +588,11 @@ struct PagedAttentionManager { return get_memory_from_vec(adaptive_rkv_diversity_block_set_indices_begins); } + memory::ptr get_token_type_ids_memory() { + std::vector token_type_ids = { 0 }; + return get_memory_from_vec(token_type_ids); + } + float get_default_scale() { return static_cast(1.f / std::sqrt(k_head_size)); } @@ -1465,6 +1470,7 @@ struct PagedAttentionTest : public ::testing::TestWithParam { auto adaptive_rkv_evictable_sizes_mem = pam.get_adaptive_rkv_evictable_sizes_memory(); auto adaptive_rkv_diversity_block_set_indices_mem = pam.get_adaptive_rkv_diversity_block_set_indices_memory(); auto adaptive_rkv_diversity_block_set_indices_begins_mem = pam.get_adaptive_rkv_diversity_block_set_indices_begins_memory(); + auto token_type_ids_mem = pam.get_token_type_ids_memory(); auto query_layout = query_mem->get_layout(); auto key_layout = key_mem->get_layout(); @@ -1491,6 +1497,7 @@ struct PagedAttentionTest : public ::testing::TestWithParam { auto adaptive_rkv_evictable_sizes_layout = adaptive_rkv_evictable_sizes_mem->get_layout(); auto adaptive_rkv_diversity_block_set_indices_layout = adaptive_rkv_diversity_block_set_indices_mem->get_layout(); auto adaptive_rkv_diversity_block_set_indices_begins_layout = adaptive_rkv_diversity_block_set_indices_begins_mem->get_layout(); + auto token_type_ids_layout = token_type_ids_mem->get_layout(); // make layouts dynamic query_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads * p.k_head_size }); @@ -1578,6 +1585,7 @@ struct PagedAttentionTest : public ::testing::TestWithParam { input_info("adaptive_rkv_evictable_sizes"), input_info("adaptive_rkv_diversity_block_set_indices"), 
input_info("adaptive_rkv_diversity_block_set_indices_begins"), + input_info("token_type_ids"), }; auto pa_prim = paged_attention("paged_attention", pa_inputs); @@ -1649,6 +1657,7 @@ struct PagedAttentionTest : public ::testing::TestWithParam { topology.add(input_layout("adaptive_rkv_evictable_sizes", adaptive_rkv_evictable_sizes_layout)); topology.add(input_layout("adaptive_rkv_diversity_block_set_indices", adaptive_rkv_diversity_block_set_indices_layout)); topology.add(input_layout("adaptive_rkv_diversity_block_set_indices_begins", adaptive_rkv_diversity_block_set_indices_begins_layout)); + topology.add(input_layout("token_type_ids", token_type_ids_layout)); } ExecutionConfig config = get_test_default_config(get_test_engine()); @@ -1683,6 +1692,7 @@ struct PagedAttentionTest : public ::testing::TestWithParam { network->set_input_data("adaptive_rkv_evictable_sizes", adaptive_rkv_evictable_sizes_mem); network->set_input_data("adaptive_rkv_diversity_block_set_indices", adaptive_rkv_diversity_block_set_indices_mem); network->set_input_data("adaptive_rkv_diversity_block_set_indices_begins", adaptive_rkv_diversity_block_set_indices_begins_mem); + network->set_input_data("token_type_ids", token_type_ids_mem); auto outputs = network->execute();