@@ -71,16 +71,16 @@ class TransposeValueTensors : public ov::pass::MatcherPass {
7171 }
7272};
7373
74- // llama2 pattern for value tensor concate
75- class TransposeValueTensors_llama2 : public TransposeValueTensors {
74+ // MHA (Multi-Head Attention) pattern for value tensor concatenation
75+ class TransposeValueTensors_MHA : public TransposeValueTensors {
7676public:
77- OPENVINO_MATCHER_PASS_RTTI (" npuw::LLMCompiledModel::TransposeValueTensors_llama2 " );
78- TransposeValueTensors_llama2 (Context::Ref ctx) {
79- register_matcher_llama2 (ctx);
77+ OPENVINO_MATCHER_PASS_RTTI (" npuw::LLMCompiledModel::TransposeValueTensors_MHA " );
78+ TransposeValueTensors_MHA (Context::Ref ctx) {
79+ register_matcher_mha (ctx);
8080 }
8181
8282private:
83- void register_matcher_llama2 (Context::Ref ctx) {
83+ void register_matcher_mha (Context::Ref ctx) {
8484 auto param = opp::wrap_type<ov::op::v0::Parameter>();
8585 auto transpose = opp::wrap_type<ov::op::v1::Transpose>({opp::any_input (), opp::any_input ()});
8686 auto convert = opp::optional<ov::op::v0::Convert>({param->output (0 )});
@@ -101,23 +101,24 @@ class TransposeValueTensors_llama2 : public TransposeValueTensors {
101101 matched_node_concat,
102102 matched_node_transpose,
103103 matched_node_matmul);
104- LOG_DEBUG (" vtensors transposed: LLama2 pattern" );
104+ LOG_DEBUG (" vtensors transposed: MHA pattern" );
105105 return true ;
106106 };
107- register_matcher (std::make_shared<opp::Matcher>(matmul, " TransposeValueTensors_llama2 " ), std::move (callback));
107+ register_matcher (std::make_shared<opp::Matcher>(matmul, " TransposeValueTensors_MHA " ), std::move (callback));
108108 }
109109};
110110
111- // llama3, phi3, mistral, etc, concate value tensors with broadcasting
112- class TransposeValueTensors_llama3 : public TransposeValueTensors {
111+ // GQA (Grouped Query Attention) pattern for value tensors with broadcasting
112+ // Used by llama3, phi3, mistral, GPT-OSS, etc.
113+ class TransposeValueTensors_GQA : public TransposeValueTensors {
113114public:
114- OPENVINO_MATCHER_PASS_RTTI (" npuw::LLMCompiledModel::TransposeValueTensors_llama3 " );
115- TransposeValueTensors_llama3 (Context::Ref ctx) {
116- register_matcher_llama3 (ctx);
115+ OPENVINO_MATCHER_PASS_RTTI (" npuw::LLMCompiledModel::TransposeValueTensors_GQA " );
116+ TransposeValueTensors_GQA (Context::Ref ctx) {
117+ register_matcher_gqa (ctx);
117118 }
118119
119120private:
120- void register_matcher_llama3 (Context::Ref ctx) {
121+ void register_matcher_gqa (Context::Ref ctx) {
121122 auto param = opp::wrap_type<ov::op::v0::Parameter>();
122123 auto transpose = opp::wrap_type<ov::op::v1::Transpose>({opp::any_input (), opp::any_input ()});
123124 auto convert = opp::optional<ov::op::v0::Convert>({param->output (0 )});
@@ -131,7 +132,10 @@ class TransposeValueTensors_llama3 : public TransposeValueTensors {
131132
132133 // v8 softmax? what? can be other softmaxes
133134 auto softmax = opp::wrap_type<ov::op::v8::Softmax>({opp::any_input ()});
134- auto matmul = opp::wrap_type<ov::op::v0::MatMul>({softmax, reshape});
135+ // Softmax output may be sliced when SDPA with sink input is decomposed (e.g. GPT-OSS)
136+ auto maybe_slice = opp::optional<ov::op::v8::Slice>(
137+ {softmax, opp::any_input (), opp::any_input (), opp::any_input (), opp::any_input ()});
138+ auto matmul = opp::wrap_type<ov::op::v0::MatMul>({maybe_slice, reshape});
135139
136140 auto callback = [=](ov::pass::pattern::Matcher& m) {
137141 auto & node_to_output = m.get_pattern_value_map ();
@@ -177,10 +181,10 @@ class TransposeValueTensors_llama3 : public TransposeValueTensors {
177181 matched_reshape->input (1 ).replace_source_output (reshape_axes_node);
178182
179183 transpose_matmul_b (ctx, matched_param, matched_concat, matched_transpose, matched_matmul);
180- LOG_DEBUG (" vtensors transposed: LLama3 pattern" );
184+ LOG_DEBUG (" vtensors transposed: GQA pattern" );
181185 return true ;
182186 };
183- register_matcher (std::make_shared<opp::Matcher>(matmul, " TransposeValueTensors_llama3 " ), std::move (callback));
187+ register_matcher (std::make_shared<opp::Matcher>(matmul, " TransposeValueTensors_GQA " ), std::move (callback));
184188 }
185189};
186190
@@ -529,8 +533,8 @@ bool ov::npuw::util::optimize_value_tensors(std::shared_ptr<ov::Model> model, bo
529533 }
530534
531535 TransposeValueTensors::Context ctx;
532- rewr.add_matcher <TransposeValueTensors_llama2 >(std::ref (ctx));
533- rewr.add_matcher <TransposeValueTensors_llama3 >(std::ref (ctx));
536+ rewr.add_matcher <TransposeValueTensors_MHA >(std::ref (ctx));
537+ rewr.add_matcher <TransposeValueTensors_GQA >(std::ref (ctx));
534538 rewr.run_on_model (model);
535539
536540 ov::pass::Validate ().run_on_model (model);
0 commit comments