Skip to content

Commit a37bcb7

Browse files
[NPUW] Transpose V tensor for Softmax - Slice - Matmul. (#33807)
### Details: GPT-OSS SDPA has a sink input, so a pair of Concat and Slice ops appears around Softmax. The existing V tensor transpose could not handle this pattern. This PR extends the V tensor transpose to cover the GPT-OSS pattern, eliminating the permutation in the compiler. ### Tickets: - *[EISW-200448](https://jira.devtools.intel.com/browse/EISW-200448)* --------- Signed-off-by: intelgaoxiong <xiong.gao@intel.com>
1 parent 177013e commit a37bcb7

File tree

2 files changed

+62
-41
lines changed

2 files changed

+62
-41
lines changed

src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model_utils.cpp

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -71,22 +71,25 @@ class TransposeValueTensors : public ov::pass::MatcherPass {
7171
}
7272
};
7373

74-
// llama2 pattern for value tensor concate
75-
class TransposeValueTensors_llama2 : public TransposeValueTensors {
74+
// MHA (Multi-Head Attention) pattern for value tensor concatenation
75+
class TransposeValueTensors_MHA : public TransposeValueTensors {
7676
public:
77-
OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::TransposeValueTensors_llama2");
78-
TransposeValueTensors_llama2(Context::Ref ctx) {
79-
register_matcher_llama2(ctx);
77+
OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::TransposeValueTensors_MHA");
78+
TransposeValueTensors_MHA(Context::Ref ctx) {
79+
register_matcher_mha(ctx);
8080
}
8181

8282
private:
83-
void register_matcher_llama2(Context::Ref ctx) {
83+
void register_matcher_mha(Context::Ref ctx) {
8484
auto param = opp::wrap_type<ov::op::v0::Parameter>();
8585
auto transpose = opp::wrap_type<ov::op::v1::Transpose>({opp::any_input(), opp::any_input()});
8686
auto convert = opp::optional<ov::op::v0::Convert>({param->output(0)});
8787
auto concat = opp::wrap_type<ov::op::v0::Concat>({convert, transpose});
8888
auto softmax = opp::wrap_type<ov::op::v8::Softmax>({opp::any_input()});
89-
auto matmul = opp::wrap_type<ov::op::v0::MatMul>({softmax, concat});
89+
// Softmax output maybe sliced when SDPA with sink input is decomposed
90+
auto maybe_slice = opp::optional<ov::op::v8::Slice>(
91+
{softmax, opp::any_input(), opp::any_input(), opp::any_input(), opp::any_input()});
92+
auto matmul = opp::wrap_type<ov::op::v0::MatMul>({maybe_slice, concat});
9093

9194
auto callback = [=](ov::pass::pattern::Matcher& m) {
9295
auto& node_to_output = m.get_pattern_value_map();
@@ -101,23 +104,24 @@ class TransposeValueTensors_llama2 : public TransposeValueTensors {
101104
matched_node_concat,
102105
matched_node_transpose,
103106
matched_node_matmul);
104-
LOG_DEBUG("vtensors transposed: LLama2 pattern");
107+
LOG_DEBUG("vtensors transposed: MHA pattern");
105108
return true;
106109
};
107-
register_matcher(std::make_shared<opp::Matcher>(matmul, "TransposeValueTensors_llama2"), std::move(callback));
110+
register_matcher(std::make_shared<opp::Matcher>(matmul, "TransposeValueTensors_MHA"), std::move(callback));
108111
}
109112
};
110113

111-
// llama3, phi3, mistral, etc, concate value tensors with broadcasting
112-
class TransposeValueTensors_llama3 : public TransposeValueTensors {
114+
// GQA (Grouped Query Attention) pattern for value tensors with broadcasting
115+
// Used by llama3, phi3, mistral, GPT-OSS, etc.
116+
class TransposeValueTensors_GQA : public TransposeValueTensors {
113117
public:
114-
OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::TransposeValueTensors_llama3");
115-
TransposeValueTensors_llama3(Context::Ref ctx) {
116-
register_matcher_llama3(ctx);
118+
OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::TransposeValueTensors_GQA");
119+
TransposeValueTensors_GQA(Context::Ref ctx) {
120+
register_matcher_gqa(ctx);
117121
}
118122

119123
private:
120-
void register_matcher_llama3(Context::Ref ctx) {
124+
void register_matcher_gqa(Context::Ref ctx) {
121125
auto param = opp::wrap_type<ov::op::v0::Parameter>();
122126
auto transpose = opp::wrap_type<ov::op::v1::Transpose>({opp::any_input(), opp::any_input()});
123127
auto convert = opp::optional<ov::op::v0::Convert>({param->output(0)});
@@ -131,7 +135,10 @@ class TransposeValueTensors_llama3 : public TransposeValueTensors {
131135

132136
// v8 softmax? what? can be other softmaxes
133137
auto softmax = opp::wrap_type<ov::op::v8::Softmax>({opp::any_input()});
134-
auto matmul = opp::wrap_type<ov::op::v0::MatMul>({softmax, reshape});
138+
// Softmax output maybe sliced when SDPA with sink input is decomposed (e.g. GPT-OSS)
139+
auto maybe_slice = opp::optional<ov::op::v8::Slice>(
140+
{softmax, opp::any_input(), opp::any_input(), opp::any_input(), opp::any_input()});
141+
auto matmul = opp::wrap_type<ov::op::v0::MatMul>({maybe_slice, reshape});
135142

136143
auto callback = [=](ov::pass::pattern::Matcher& m) {
137144
auto& node_to_output = m.get_pattern_value_map();
@@ -177,10 +184,10 @@ class TransposeValueTensors_llama3 : public TransposeValueTensors {
177184
matched_reshape->input(1).replace_source_output(reshape_axes_node);
178185

179186
transpose_matmul_b(ctx, matched_param, matched_concat, matched_transpose, matched_matmul);
180-
LOG_DEBUG("vtensors transposed: LLama3 pattern");
187+
LOG_DEBUG("vtensors transposed: GQA pattern");
181188
return true;
182189
};
183-
register_matcher(std::make_shared<opp::Matcher>(matmul, "TransposeValueTensors_llama3"), std::move(callback));
190+
register_matcher(std::make_shared<opp::Matcher>(matmul, "TransposeValueTensors_GQA"), std::move(callback));
184191
}
185192
};
186193

@@ -529,8 +536,8 @@ bool ov::npuw::util::optimize_value_tensors(std::shared_ptr<ov::Model> model, bo
529536
}
530537

531538
TransposeValueTensors::Context ctx;
532-
rewr.add_matcher<TransposeValueTensors_llama2>(std::ref(ctx));
533-
rewr.add_matcher<TransposeValueTensors_llama3>(std::ref(ctx));
539+
rewr.add_matcher<TransposeValueTensors_MHA>(std::ref(ctx));
540+
rewr.add_matcher<TransposeValueTensors_GQA>(std::ref(ctx));
534541
rewr.run_on_model(model);
535542

536543
ov::pass::Validate().run_on_model(model);

src/plugins/intel_npu/tests/unit/npuw/transpose_vt.cpp

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ namespace npuw_utest{
2020
}
2121

2222
enum class NetworkKind {
23-
llama2,
24-
llama3
23+
MHA, // Multi-Head Attention (e.g., llama2) - no broadcast
24+
GQA // Grouped Query Attention (e.g., llama3, phi3, mistral, GPT-OSS) - with broadcast
2525
};
2626

2727
typedef std::tuple <
@@ -30,6 +30,7 @@ typedef std::tuple <
3030
bool, // withTranspose - without transpose node - matcher shouldnt detect subgraph, easy way to negative case
3131
bool, // withSDPA - should SDPA layer present or be already unrolled or simplified
3232
bool, // use high precision on attention_mask input
33+
bool, // withSink - SDPA with 6th input (sink) for GPT-OSS pattern
3334
NetworkKind
3435
> OptimizeVTTestParamsTuple;
3536

@@ -41,11 +42,12 @@ struct OptimizeVTTestParams {
4142
_AT(2) withTranspose;
4243
_AT(3) withSDPA;
4344
_AT(4) withHpAttenMask;
44-
_AT(5) kind;
45+
_AT(5) withSink;
46+
_AT(6) kind;
4547
#undef _AT
4648

4749
OptimizeVTTestParams(const OptimizeVTTestParamsTuple& tup) {
48-
std::tie(inputShape, withConvert, withTranspose, withSDPA, withHpAttenMask, kind) = tup;
50+
std::tie(inputShape, withConvert, withTranspose, withSDPA, withHpAttenMask, withSink, kind) = tup;
4951
}
5052
};
5153

@@ -75,8 +77,9 @@ class TransposeVTTest : public testing::WithParamInterface<OptimizeVTTestParamsT
7577

7678

7779
// validation of High Precision attention mask - implies enabling SDPA layer to be unrolled,
78-
// and also specific FP16 activation transformation in partitioning
79-
if (test.withSDPA) {
80+
// and also specific FP16 activation transformation in partitioning
81+
// Note: When withSink=true, standard OpenVINO SDPA decomposition is used which doesn't support HP
82+
if (test.withSDPA && !test.withSink) {
8083
std::shared_ptr<::intel_npu::OptionsDesc> options_desc;
8184

8285
auto opt_desc = std::make_shared<::intel_npu::OptionsDesc>();
@@ -187,9 +190,10 @@ class TransposeVTTest : public testing::WithParamInterface<OptimizeVTTestParamsT
187190

188191
std::ostringstream result;
189192
result << "npuw_llm_pipeline_" << test.inputShape << "_"
190-
<< (test.kind == NetworkKind::llama3 ? "LLAMA3" : "LLAMA2")
193+
<< (test.kind == NetworkKind::MHA ? "MHA" : "GQA")
191194
<< (test.withConvert ? "_with_convert" : "")
192195
<< (test.withSDPA ? "_SDPA" : "")
196+
<< (test.withSink ? "_Sink" : "")
193197
<< (test.withHpAttenMask ? "_HP" : "")
194198
<< (!test.withTranspose ? "_NEGATIVE" : "");
195199
return result.str();
@@ -212,7 +216,7 @@ class TransposeVTTest : public testing::WithParamInterface<OptimizeVTTestParamsT
212216
};
213217

214218
// in case of non broadcast number of input channels significantly smaller
215-
auto numChannels = (test.kind == NetworkKind::llama3) ? 8 : 32;
219+
auto numChannels = (test.kind == NetworkKind::MHA) ? 32 : 8;
216220
auto input_shape = test.inputShape;
217221
auto input_2 = static_cast<int>(test.inputShape[2]);
218222
auto input_3 = static_cast<int>(test.inputShape[3]);
@@ -256,7 +260,7 @@ class TransposeVTTest : public testing::WithParamInterface<OptimizeVTTestParamsT
256260

257261
std::shared_ptr<ov::Node> concat_or_reshape = concat;
258262

259-
if (test.kind == NetworkKind::llama3) {
263+
if (test.kind == NetworkKind::GQA) {
260264
auto unsqueeze_pattern = create_shape_constant({2}, "unsqueese_pattern");
261265
auto unsqueeze = std::make_shared<ov::op::v0::Unsqueeze>(concat, unsqueeze_pattern);
262266
unsqueeze->set_friendly_name("unsqueeze");
@@ -285,25 +289,32 @@ class TransposeVTTest : public testing::WithParamInterface<OptimizeVTTestParamsT
285289
auto k_input = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 32, input_shape[2] + 1, input_shape[3]});
286290
k_input->set_friendly_name("k_input");
287291

288-
289292
auto q_input = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 32, input_shape[2] + 1, input_shape[3]});
290293
q_input->set_friendly_name("q_input");
291294

292295
auto scale_node = ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {1});
293296

294-
// TODO: add sdpa subgraph
295-
std::shared_ptr<ov::Node> sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(
296-
q_input,
297-
k_input,
298-
concat_or_reshape,
299-
mask_input_1,
300-
scale_node,
301-
false);
297+
std::shared_ptr<ov::Node> sdpa;
298+
ov::ParameterVector params = {param, param2, mask_input, k_input, q_input};
299+
300+
// SDPA with sink (6 inputs) for GPT-OSS pattern
301+
if (test.withSink) {
302+
auto sink = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 32, 1, 1});
303+
sink->set_friendly_name("sink");
304+
params.push_back(sink);
305+
306+
sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(
307+
q_input, k_input, concat_or_reshape, mask_input_1, scale_node, sink, false);
308+
} else {
309+
sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(
310+
q_input, k_input, concat_or_reshape, mask_input_1, scale_node, false);
311+
}
312+
302313
sdpa->set_friendly_name("sdpa");
303314
auto result = std::make_shared<ov::op::v0::Result>(sdpa);
304315

305316
result->set_friendly_name("res");
306-
return std::make_shared<ov::Model>(ov::ResultVector{result}, ov::ParameterVector{param, param2, mask_input, k_input, q_input});
317+
return std::make_shared<ov::Model>(ov::ResultVector{result}, params);
307318

308319
} else {
309320
auto param3 = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 32, 1, input_shape[2] + 1});
@@ -344,9 +355,11 @@ const std::vector<bool> withSDPA{true, false};
344355

345356
const std::vector<bool> withHpAttenMask{true, false};
346357

358+
const std::vector<bool> withSink{true, false};
359+
347360
const std::vector<NetworkKind> networkKind = {
348-
// llama2 or llama3 type of concat, with convert layer or without
349-
NetworkKind::llama2, NetworkKind::llama3
361+
NetworkKind::MHA, // Multi-Head Attention
362+
NetworkKind::GQA // Grouped Query Attention
350363
};
351364

352365
INSTANTIATE_TEST_SUITE_P(smoke_Run_MatchAndTransposeVT,
@@ -357,6 +370,7 @@ const std::vector<NetworkKind> networkKind = {
357370
::testing::ValuesIn(withBroadCast),
358371
::testing::ValuesIn(withSDPA),
359372
::testing::ValuesIn(withHpAttenMask),
373+
::testing::ValuesIn(withSink),
360374
::testing::ValuesIn(networkKind)),
361375
TransposeVTTest::getTestCaseName);
362376

0 commit comments

Comments
 (0)