Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
1abe8ec
WIP
p-wysocki Feb 9, 2026
34763dd
WIP
p-wysocki Feb 9, 2026
e7f8238
WIP
p-wysocki Feb 9, 2026
daef794
WIP
p-wysocki Feb 12, 2026
fdb3a73
Add tests
p-wysocki Feb 12, 2026
829c430
initial cleanup
p-wysocki Feb 12, 2026
a04d165
Set input as optional
p-wysocki Feb 12, 2026
96188fb
Correct tests
p-wysocki Feb 12, 2026
4d9f607
Remove reshape from graph
p-wysocki Feb 12, 2026
1f6a6d1
Remove debug prints
p-wysocki Feb 12, 2026
2bf68b5
Cleanup
p-wysocki Feb 13, 2026
bcbb855
Merge branch 'master' into attn_idea_2
p-wysocki Feb 13, 2026
ed3374d
Sliding window working
p-wysocki Feb 18, 2026
0fd5001
Move sw to gptoss logic
p-wysocki Feb 18, 2026
5fbbf5e
Working, with debug prints
p-wysocki Feb 18, 2026
81ab320
Cleanup
p-wysocki Feb 18, 2026
4e0d5ac
Merge branch 'attn_idea_2' of https://github.com/p-wysocki/openvino i…
p-wysocki Feb 18, 2026
5f5af24
Cleanup
p-wysocki Feb 18, 2026
362cb80
update copyright
p-wysocki Feb 18, 2026
7841c70
Fix transformation tests, add new one
p-wysocki Feb 18, 2026
9810dc3
Fix convert input tests
p-wysocki Feb 18, 2026
890804b
Fix clang
p-wysocki Feb 18, 2026
6a62dda
Fix smoke tests
p-wysocki Feb 18, 2026
dfc6e1f
Fix smoke test
p-wysocki Feb 18, 2026
a412f6c
Update GPU input count
p-wysocki Feb 25, 2026
810130d
CR
p-wysocki Mar 2, 2026
acbd73e
Add token_type_ids to gemma only
p-wysocki Mar 2, 2026
2da2303
Fix gpu test
p-wysocki Mar 2, 2026
03e9935
Merge branch 'master' into attn_idea_2
p-wysocki Mar 10, 2026
e2347e1
Merge branch 'master' of https://github.com/openvinotoolkit/openvino …
p-wysocki Mar 11, 2026
7176099
Fix conflict issues
p-wysocki Mar 11, 2026
af07897
Apply CR
p-wysocki Mar 12, 2026
c99b9c5
Merge branch 'master' of https://github.com/openvinotoolkit/openvino …
p-wysocki Mar 12, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -441,9 +441,10 @@ ov::pass::StateManagementPattern::StateManagementPattern(
// Shared flag to track whether the model is Gemma3, set when any layer matches
// the gptoss_gemma3 sliding window pattern. Combined with the token_type_ids check,
// this uniquely identifies Gemma3 (gpt-oss shares the pattern but lacks token_type_ids).
auto is_gptoss_gemma3 = std::make_shared<bool>(false);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we define this variable inside the callback?

bool is_gemma3 = false;

ov::matcher_pass_callback callback = [=,
&is_gemma3,
&kv_parameters,
&model_wide_params,
&block_indices_inputs_for_each_layer,
Expand Down Expand Up @@ -621,7 +622,7 @@ ov::pass::StateManagementPattern::StateManagementPattern(
}
sliding_window = std::make_shared<v1::Subtract>(v0::Constant::create(element::i32, Shape{}, {2}), offset);
} else if (pattern_map.count(gptoss_gemma3_offset)) {
*is_gptoss_gemma3 = true;
is_gemma3 = optional_model_wide_params.count("token_type_ids");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact any model with token_type_ids and matching sliding window pattern will set this is_gemma3 flag true, why not simply name this variable has_token_type_ids?
Or set has_sliding_window here instead, and use below.
Also currently is_gemma3 will be false for causal mask case (no sliding window) within the same model.

auto offset = pattern_map.at(gptoss_gemma3_offset).get_node_shared_ptr();
if (pattern_map.at(gptoss_gemma3_offset).get_partial_shape().rank() != 0) {
offset = std::make_shared<v15::Squeeze>(offset);
Expand Down Expand Up @@ -756,7 +757,7 @@ ov::pass::StateManagementPattern::StateManagementPattern(
}
OPENVINO_ASSERT(pa_arguments.size() == 25);

if (*is_gptoss_gemma3) {
if (is_gemma3) {
pa_arguments.insert(pa_arguments.begin() + 25, handle_gemma3_token_type_ids(optional_model_wide_params));
} else {
pa_arguments.insert(pa_arguments.begin() + 25, v0::Constant::create(element::i32, Shape{0}, {}));
Comment on lines +760 to 763
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable naming is tied to gemma3, but it could be generic for any model that has both has_token_type_ids and has_sliding_window set to true.
It is currently applied for sliding_window case only, but as a next step it could be extended to causal case as well then this if else will be reduced to single case:

pa_arguments.insert(pa_arguments.begin() + 25, handle_token_type_ids(optional_model_wide_params));

Suggested change
if (is_gemma3) {
pa_arguments.insert(pa_arguments.begin() + 25, handle_gemma3_token_type_ids(optional_model_wide_params));
} else {
pa_arguments.insert(pa_arguments.begin() + 25, v0::Constant::create(element::i32, Shape{0}, {}));
if (has_sliding_window) {
pa_arguments.insert(pa_arguments.begin() + 25, handle_token_type_ids(optional_model_wide_params));
} else {
pa_arguments.insert(pa_arguments.begin() + 25, v0::Constant::create(element::i32, Shape{0}, {}));

Expand Down
93 changes: 93 additions & 0 deletions src/core/tests/type_prop/paged_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,5 +270,98 @@ TEST(type_prop, paged_attention_invalid_rank_key_cache) {
EXPECT_THROW(std::ignore = std::make_shared<op::PagedAttentionExtension>(args), ov::NodeValidationFailure);
}

// Builds the full positional argument list for PagedAttentionExtension, with the
// supplied `token_type_ids` appended as the last input (index 25). All other
// inputs are fresh Parameters with the fixed element types and shapes the
// type-prop tests rely on.
static ov::OutputVector make_args_with_token_type(const std::shared_ptr<ov::op::v0::Parameter>& token_type_ids) {
    using namespace ov::op;
    // Small factories so each PagedAttention input is declared on one line.
    const auto f32_in = [](ov::PartialShape shape) {
        return std::make_shared<v0::Parameter>(ov::element::f32, std::move(shape));
    };
    const auto i32_in = [](ov::PartialShape shape) {
        return std::make_shared<v0::Parameter>(ov::element::i32, std::move(shape));
    };

    // Order matters: these are the positional inputs of PagedAttentionExtension.
    return {f32_in({3, 4}),        // query
            f32_in({3, 4}),        // key
            f32_in({3, 4}),        // value
            f32_in({6, 2, 5, 4}),  // key_cache
            f32_in({6, 2, 5, 4}),  // value_cache
            i32_in({5}),           // past_lens
            i32_in({5}),           // subsequence_begins
            i32_in({15}),          // block_indices
            i32_in({8}),           // block_indices_begins
            f32_in({}),            // scale
            i32_in({}),            // sliding_window
            f32_in({9}),           // alibi_slopes
            i32_in({}),            // max_context_len
            i32_in({5}),           // score_aggregation_window
            i32_in({3}),           // rotated_block_indices
            i32_in({12, 1}),       // rotation_deltas
            f32_in({256, 4}),      // rotation_trig_lut
            f32_in({5}),           // xattention_threshold
            i32_in({}),            // xattention_block_size
            i32_in({}),            // xattention_stride
            f32_in({1, 2, 1, 1}),  // sinks
            i32_in({}),            // adaptive_rkv_start_size
            i32_in({5}),           // adaptive_rkv_evictable_sizes
            i32_in({10}),          // adaptive_rkv_diversity_block_set_indices
            i32_in({5}),           // adaptive_rkv_diversity_block_set_indices_begins
            token_type_ids};       // token_type_ids — the input under test
}

TEST(type_prop, paged_attention_token_type_ids_1d) {
    // A rank-1 i32 token_type_ids input is accepted and leaves the inferred
    // output element type and shape unchanged.
    const auto ids = std::make_shared<op::v0::Parameter>(ov::element::i32, ov::PartialShape{3});
    const auto pa = std::make_shared<op::PagedAttentionExtension>(make_args_with_token_type(ids));
    EXPECT_EQ(pa->get_output_element_type(0), ov::element::f32);
    EXPECT_EQ(pa->get_output_partial_shape(0), (ov::PartialShape{3, 4}));
}

TEST(type_prop, paged_attention_token_type_ids_2d) {
    // A rank-2 ({1, 3}) i32 token_type_ids input is also accepted, with the
    // same inferred output type and shape as the 1-D case.
    const auto ids = std::make_shared<op::v0::Parameter>(ov::element::i32, ov::PartialShape{1, 3});
    const auto pa = std::make_shared<op::PagedAttentionExtension>(make_args_with_token_type(ids));
    EXPECT_EQ(pa->get_output_element_type(0), ov::element::f32);
    EXPECT_EQ(pa->get_output_partial_shape(0), (ov::PartialShape{3, 4}));
}

TEST(type_prop, paged_attention_token_type_ids_dynamic_shape) {
    // token_type_ids with a rank-1 dynamic dimension must pass validation.
    const auto ids =
        std::make_shared<op::v0::Parameter>(ov::element::i32, ov::PartialShape{ov::Dimension::dynamic()});
    EXPECT_NO_THROW(std::ignore = std::make_shared<op::PagedAttentionExtension>(make_args_with_token_type(ids)));
}

TEST(type_prop, paged_attention_invalid_type_token_type_ids) {
    // An f32 token_type_ids input must be rejected by validation.
    const auto bad_ids = std::make_shared<op::v0::Parameter>(ov::element::f32, ov::PartialShape{3});
    EXPECT_THROW(std::ignore = std::make_shared<op::PagedAttentionExtension>(make_args_with_token_type(bad_ids)),
                 ov::NodeValidationFailure);
}

TEST(type_prop, paged_attention_invalid_rank_token_type_ids) {
    // A rank-3 token_type_ids input must be rejected by validation.
    const auto bad_ids = std::make_shared<op::v0::Parameter>(ov::element::i32, ov::PartialShape{1, 1, 3});
    EXPECT_THROW(std::ignore = std::make_shared<op::PagedAttentionExtension>(make_args_with_token_type(bad_ids)),
                 ov::NodeValidationFailure);
}

} // namespace testing
} // namespace ov
Loading