Skip to content

Commit 55c4fd5

Browse files
committed
Transpose V tensor for Softmax - Slice - Matmul.
Signed-off-by: intelgaoxiong <xiong.gao@intel.com>
1 parent 9300858 commit 55c4fd5

File tree

1 file changed

+23
-19
lines changed

1 file changed

+23
-19
lines changed

src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model_utils.cpp

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -71,16 +71,16 @@ class TransposeValueTensors : public ov::pass::MatcherPass {
7171
}
7272
};
7373

74-
// llama2 pattern for value tensor concate
75-
class TransposeValueTensors_llama2 : public TransposeValueTensors {
74+
// MHA (Multi-Head Attention) pattern for value tensor concatenation
75+
class TransposeValueTensors_MHA : public TransposeValueTensors {
7676
public:
77-
OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::TransposeValueTensors_llama2");
78-
TransposeValueTensors_llama2(Context::Ref ctx) {
79-
register_matcher_llama2(ctx);
77+
OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::TransposeValueTensors_MHA");
78+
TransposeValueTensors_MHA(Context::Ref ctx) {
79+
register_matcher_mha(ctx);
8080
}
8181

8282
private:
83-
void register_matcher_llama2(Context::Ref ctx) {
83+
void register_matcher_mha(Context::Ref ctx) {
8484
auto param = opp::wrap_type<ov::op::v0::Parameter>();
8585
auto transpose = opp::wrap_type<ov::op::v1::Transpose>({opp::any_input(), opp::any_input()});
8686
auto convert = opp::optional<ov::op::v0::Convert>({param->output(0)});
@@ -101,23 +101,24 @@ class TransposeValueTensors_llama2 : public TransposeValueTensors {
101101
matched_node_concat,
102102
matched_node_transpose,
103103
matched_node_matmul);
104-
LOG_DEBUG("vtensors transposed: LLama2 pattern");
104+
LOG_DEBUG("vtensors transposed: MHA pattern");
105105
return true;
106106
};
107-
register_matcher(std::make_shared<opp::Matcher>(matmul, "TransposeValueTensors_llama2"), std::move(callback));
107+
register_matcher(std::make_shared<opp::Matcher>(matmul, "TransposeValueTensors_MHA"), std::move(callback));
108108
}
109109
};
110110

111-
// llama3, phi3, mistral, etc, concate value tensors with broadcasting
112-
class TransposeValueTensors_llama3 : public TransposeValueTensors {
111+
// GQA (Grouped Query Attention) pattern for value tensors with broadcasting
112+
// Used by llama3, phi3, mistral, GPT-OSS, etc.
113+
class TransposeValueTensors_GQA : public TransposeValueTensors {
113114
public:
114-
OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::TransposeValueTensors_llama3");
115-
TransposeValueTensors_llama3(Context::Ref ctx) {
116-
register_matcher_llama3(ctx);
115+
OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::TransposeValueTensors_GQA");
116+
TransposeValueTensors_GQA(Context::Ref ctx) {
117+
register_matcher_gqa(ctx);
117118
}
118119

119120
private:
120-
void register_matcher_llama3(Context::Ref ctx) {
121+
void register_matcher_gqa(Context::Ref ctx) {
121122
auto param = opp::wrap_type<ov::op::v0::Parameter>();
122123
auto transpose = opp::wrap_type<ov::op::v1::Transpose>({opp::any_input(), opp::any_input()});
123124
auto convert = opp::optional<ov::op::v0::Convert>({param->output(0)});
@@ -131,7 +132,10 @@ class TransposeValueTensors_llama3 : public TransposeValueTensors {
131132

132133
// v8 softmax? what? can be other softmaxes
133134
auto softmax = opp::wrap_type<ov::op::v8::Softmax>({opp::any_input()});
134-
auto matmul = opp::wrap_type<ov::op::v0::MatMul>({softmax, reshape});
135+
// Softmax output may be sliced when SDPA with a sink input is decomposed (e.g. GPT-OSS)
136+
auto maybe_slice = opp::optional<ov::op::v8::Slice>(
137+
{softmax, opp::any_input(), opp::any_input(), opp::any_input(), opp::any_input()});
138+
auto matmul = opp::wrap_type<ov::op::v0::MatMul>({maybe_slice, reshape});
135139

136140
auto callback = [=](ov::pass::pattern::Matcher& m) {
137141
auto& node_to_output = m.get_pattern_value_map();
@@ -177,10 +181,10 @@ class TransposeValueTensors_llama3 : public TransposeValueTensors {
177181
matched_reshape->input(1).replace_source_output(reshape_axes_node);
178182

179183
transpose_matmul_b(ctx, matched_param, matched_concat, matched_transpose, matched_matmul);
180-
LOG_DEBUG("vtensors transposed: LLama3 pattern");
184+
LOG_DEBUG("vtensors transposed: GQA pattern");
181185
return true;
182186
};
183-
register_matcher(std::make_shared<opp::Matcher>(matmul, "TransposeValueTensors_llama3"), std::move(callback));
187+
register_matcher(std::make_shared<opp::Matcher>(matmul, "TransposeValueTensors_GQA"), std::move(callback));
184188
}
185189
};
186190

@@ -529,8 +533,8 @@ bool ov::npuw::util::optimize_value_tensors(std::shared_ptr<ov::Model> model, bo
529533
}
530534

531535
TransposeValueTensors::Context ctx;
532-
rewr.add_matcher<TransposeValueTensors_llama2>(std::ref(ctx));
533-
rewr.add_matcher<TransposeValueTensors_llama3>(std::ref(ctx));
536+
rewr.add_matcher<TransposeValueTensors_MHA>(std::ref(ctx));
537+
rewr.add_matcher<TransposeValueTensors_GQA>(std::ref(ctx));
534538
rewr.run_on_model(model);
535539

536540
ov::pass::Validate().run_on_model(model);

0 commit comments

Comments
 (0)