-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Minor improvements to token_type_ids extension for PA
#34661
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
1abe8ec
34763dd
e7f8238
daef794
fdb3a73
829c430
a04d165
96188fb
4d9f607
1f6a6d1
2bf68b5
bcbb855
ed3374d
0fd5001
5fbbf5e
81ab320
4e0d5ac
5f5af24
362cb80
7841c70
9810dc3
890804b
6a62dda
dfc6e1f
a412f6c
810130d
acbd73e
2da2303
03e9935
e2347e1
7176099
af07897
c99b9c5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -441,9 +441,10 @@ ov::pass::StateManagementPattern::StateManagementPattern( | |||||||||||||||||
| // Shared flag to track whether the model is Gemma3, set when any layer matches | ||||||||||||||||||
| // the gptoss_gemma3 sliding window pattern. Combined with the token_type_ids check, | ||||||||||||||||||
| // this uniquely identifies Gemma3 (gpt-oss shares the pattern but lacks token_type_ids). | ||||||||||||||||||
| auto is_gptoss_gemma3 = std::make_shared<bool>(false); | ||||||||||||||||||
| bool is_gemma3 = false; | ||||||||||||||||||
|
|
||||||||||||||||||
| ov::matcher_pass_callback callback = [=, | ||||||||||||||||||
| &is_gemma3, | ||||||||||||||||||
| &kv_parameters, | ||||||||||||||||||
| &model_wide_params, | ||||||||||||||||||
| &block_indices_inputs_for_each_layer, | ||||||||||||||||||
|
|
@@ -621,7 +622,7 @@ ov::pass::StateManagementPattern::StateManagementPattern( | |||||||||||||||||
| } | ||||||||||||||||||
| sliding_window = std::make_shared<v1::Subtract>(v0::Constant::create(element::i32, Shape{}, {2}), offset); | ||||||||||||||||||
| } else if (pattern_map.count(gptoss_gemma3_offset)) { | ||||||||||||||||||
| *is_gptoss_gemma3 = true; | ||||||||||||||||||
| is_gemma3 = optional_model_wide_params.count("token_type_ids"); | ||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In fact any model with |
||||||||||||||||||
| auto offset = pattern_map.at(gptoss_gemma3_offset).get_node_shared_ptr(); | ||||||||||||||||||
| if (pattern_map.at(gptoss_gemma3_offset).get_partial_shape().rank() != 0) { | ||||||||||||||||||
| offset = std::make_shared<v15::Squeeze>(offset); | ||||||||||||||||||
|
|
@@ -756,7 +757,7 @@ ov::pass::StateManagementPattern::StateManagementPattern( | |||||||||||||||||
| } | ||||||||||||||||||
| OPENVINO_ASSERT(pa_arguments.size() == 25); | ||||||||||||||||||
|
|
||||||||||||||||||
| if (*is_gptoss_gemma3) { | ||||||||||||||||||
| if (is_gemma3) { | ||||||||||||||||||
| pa_arguments.insert(pa_arguments.begin() + 25, handle_gemma3_token_type_ids(optional_model_wide_params)); | ||||||||||||||||||
| } else { | ||||||||||||||||||
| pa_arguments.insert(pa_arguments.begin() + 25, v0::Constant::create(element::i32, Shape{0}, {})); | ||||||||||||||||||
|
Comment on lines
+760
to
763
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The variable naming is tight to gemma3 but it can be generic for any model having
Suggested change
|
||||||||||||||||||
|
|
||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we define this variable inside the callback?