Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -71,22 +71,25 @@ class TransposeValueTensors : public ov::pass::MatcherPass {
}
};

// llama2 pattern for value tensor concate
class TransposeValueTensors_llama2 : public TransposeValueTensors {
// MHA (Multi-Head Attention) pattern for value tensor concatenation
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you implement test(s) for transformation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rkazants Done
Thanks!

class TransposeValueTensors_MHA : public TransposeValueTensors {
public:
OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::TransposeValueTensors_llama2");
TransposeValueTensors_llama2(Context::Ref ctx) {
register_matcher_llama2(ctx);
OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::TransposeValueTensors_MHA");
TransposeValueTensors_MHA(Context::Ref ctx) {
register_matcher_mha(ctx);
}

private:
void register_matcher_llama2(Context::Ref ctx) {
void register_matcher_mha(Context::Ref ctx) {
auto param = opp::wrap_type<ov::op::v0::Parameter>();
auto transpose = opp::wrap_type<ov::op::v1::Transpose>({opp::any_input(), opp::any_input()});
auto convert = opp::optional<ov::op::v0::Convert>({param->output(0)});
auto concat = opp::wrap_type<ov::op::v0::Concat>({convert, transpose});
auto softmax = opp::wrap_type<ov::op::v8::Softmax>({opp::any_input()});
auto matmul = opp::wrap_type<ov::op::v0::MatMul>({softmax, concat});
// Softmax output may be sliced when SDPA with a sink input is decomposed
auto maybe_slice = opp::optional<ov::op::v8::Slice>(
{softmax, opp::any_input(), opp::any_input(), opp::any_input(), opp::any_input()});
auto matmul = opp::wrap_type<ov::op::v0::MatMul>({maybe_slice, concat});

auto callback = [=](ov::pass::pattern::Matcher& m) {
auto& node_to_output = m.get_pattern_value_map();
Expand All @@ -101,23 +104,24 @@ class TransposeValueTensors_llama2 : public TransposeValueTensors {
matched_node_concat,
matched_node_transpose,
matched_node_matmul);
LOG_DEBUG("vtensors transposed: LLama2 pattern");
LOG_DEBUG("vtensors transposed: MHA pattern");
return true;
};
register_matcher(std::make_shared<opp::Matcher>(matmul, "TransposeValueTensors_llama2"), std::move(callback));
register_matcher(std::make_shared<opp::Matcher>(matmul, "TransposeValueTensors_MHA"), std::move(callback));
}
};

// llama3, phi3, mistral, etc, concate value tensors with broadcasting
class TransposeValueTensors_llama3 : public TransposeValueTensors {
// GQA (Grouped Query Attention) pattern for value tensors with broadcasting
// Used by llama3, phi3, mistral, GPT-OSS, etc.
class TransposeValueTensors_GQA : public TransposeValueTensors {
public:
OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::TransposeValueTensors_llama3");
TransposeValueTensors_llama3(Context::Ref ctx) {
register_matcher_llama3(ctx);
OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::TransposeValueTensors_GQA");
TransposeValueTensors_GQA(Context::Ref ctx) {
register_matcher_gqa(ctx);
}

private:
void register_matcher_llama3(Context::Ref ctx) {
void register_matcher_gqa(Context::Ref ctx) {
auto param = opp::wrap_type<ov::op::v0::Parameter>();
auto transpose = opp::wrap_type<ov::op::v1::Transpose>({opp::any_input(), opp::any_input()});
auto convert = opp::optional<ov::op::v0::Convert>({param->output(0)});
Expand All @@ -131,7 +135,10 @@ class TransposeValueTensors_llama3 : public TransposeValueTensors {

// v8 softmax? what? can be other softmaxes
auto softmax = opp::wrap_type<ov::op::v8::Softmax>({opp::any_input()});
auto matmul = opp::wrap_type<ov::op::v0::MatMul>({softmax, reshape});
// Softmax output may be sliced when SDPA with a sink input is decomposed (e.g. GPT-OSS)
auto maybe_slice = opp::optional<ov::op::v8::Slice>(
{softmax, opp::any_input(), opp::any_input(), opp::any_input(), opp::any_input()});
auto matmul = opp::wrap_type<ov::op::v0::MatMul>({maybe_slice, reshape});

auto callback = [=](ov::pass::pattern::Matcher& m) {
auto& node_to_output = m.get_pattern_value_map();
Expand Down Expand Up @@ -177,10 +184,10 @@ class TransposeValueTensors_llama3 : public TransposeValueTensors {
matched_reshape->input(1).replace_source_output(reshape_axes_node);

transpose_matmul_b(ctx, matched_param, matched_concat, matched_transpose, matched_matmul);
LOG_DEBUG("vtensors transposed: LLama3 pattern");
LOG_DEBUG("vtensors transposed: GQA pattern");
return true;
};
register_matcher(std::make_shared<opp::Matcher>(matmul, "TransposeValueTensors_llama3"), std::move(callback));
register_matcher(std::make_shared<opp::Matcher>(matmul, "TransposeValueTensors_GQA"), std::move(callback));
}
};

Expand Down Expand Up @@ -529,8 +536,8 @@ bool ov::npuw::util::optimize_value_tensors(std::shared_ptr<ov::Model> model, bo
}

TransposeValueTensors::Context ctx;
rewr.add_matcher<TransposeValueTensors_llama2>(std::ref(ctx));
rewr.add_matcher<TransposeValueTensors_llama3>(std::ref(ctx));
rewr.add_matcher<TransposeValueTensors_MHA>(std::ref(ctx));
rewr.add_matcher<TransposeValueTensors_GQA>(std::ref(ctx));
rewr.run_on_model(model);

ov::pass::Validate().run_on_model(model);
Expand Down
56 changes: 35 additions & 21 deletions src/plugins/intel_npu/tests/unit/npuw/transpose_vt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ namespace npuw_utest{
}

enum class NetworkKind {
llama2,
llama3
MHA, // Multi-Head Attention (e.g., llama2) - no broadcast
GQA // Grouped Query Attention (e.g., llama3, phi3, mistral, GPT-OSS) - with broadcast
};

typedef std::tuple <
Expand All @@ -30,6 +30,7 @@ typedef std::tuple <
bool, // withTranspose - when false the transpose node is omitted, so the matcher shouldn't detect the subgraph; an easy way to build a negative case
bool, // withSDPA - whether the SDPA layer is present, or has already been unrolled or simplified
bool, // use high precision on attention_mask input
bool, // withSink - SDPA with 6th input (sink) for GPT-OSS pattern
NetworkKind
> OptimizeVTTestParamsTuple;

Expand All @@ -41,11 +42,12 @@ struct OptimizeVTTestParams {
_AT(2) withTranspose;
_AT(3) withSDPA;
_AT(4) withHpAttenMask;
_AT(5) kind;
_AT(5) withSink;
_AT(6) kind;
#undef _AT

OptimizeVTTestParams(const OptimizeVTTestParamsTuple& tup) {
std::tie(inputShape, withConvert, withTranspose, withSDPA, withHpAttenMask, kind) = tup;
std::tie(inputShape, withConvert, withTranspose, withSDPA, withHpAttenMask, withSink, kind) = tup;
}
};

Expand Down Expand Up @@ -75,8 +77,9 @@ class TransposeVTTest : public testing::WithParamInterface<OptimizeVTTestParamsT


// validation of High Precision attention mask - implies enabling SDPA layer to be unrolled,
// and also specific FP16 activation transformation in partitioning
if (test.withSDPA) {
// and also specific FP16 activation transformation in partitioning
// Note: When withSink=true, standard OpenVINO SDPA decomposition is used which doesn't support HP
if (test.withSDPA && !test.withSink) {
std::shared_ptr<::intel_npu::OptionsDesc> options_desc;

auto opt_desc = std::make_shared<::intel_npu::OptionsDesc>();
Expand Down Expand Up @@ -187,9 +190,10 @@ class TransposeVTTest : public testing::WithParamInterface<OptimizeVTTestParamsT

std::ostringstream result;
result << "npuw_llm_pipeline_" << test.inputShape << "_"
<< (test.kind == NetworkKind::llama3 ? "LLAMA3" : "LLAMA2")
<< (test.kind == NetworkKind::MHA ? "MHA" : "GQA")
<< (test.withConvert ? "_with_convert" : "")
<< (test.withSDPA ? "_SDPA" : "")
<< (test.withSink ? "_Sink" : "")
<< (test.withHpAttenMask ? "_HP" : "")
<< (!test.withTranspose ? "_NEGATIVE" : "");
return result.str();
Expand All @@ -212,7 +216,7 @@ class TransposeVTTest : public testing::WithParamInterface<OptimizeVTTestParamsT
};

// with broadcast (GQA) the number of input channels is significantly smaller than without it (MHA)
auto numChannels = (test.kind == NetworkKind::llama3) ? 8 : 32;
auto numChannels = (test.kind == NetworkKind::MHA) ? 32 : 8;
auto input_shape = test.inputShape;
auto input_2 = static_cast<int>(test.inputShape[2]);
auto input_3 = static_cast<int>(test.inputShape[3]);
Expand Down Expand Up @@ -256,7 +260,7 @@ class TransposeVTTest : public testing::WithParamInterface<OptimizeVTTestParamsT

std::shared_ptr<ov::Node> concat_or_reshape = concat;

if (test.kind == NetworkKind::llama3) {
if (test.kind == NetworkKind::GQA) {
auto unsqueeze_pattern = create_shape_constant({2}, "unsqueese_pattern");
auto unsqueeze = std::make_shared<ov::op::v0::Unsqueeze>(concat, unsqueeze_pattern);
unsqueeze->set_friendly_name("unsqueeze");
Expand Down Expand Up @@ -285,25 +289,32 @@ class TransposeVTTest : public testing::WithParamInterface<OptimizeVTTestParamsT
auto k_input = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 32, input_shape[2] + 1, input_shape[3]});
k_input->set_friendly_name("k_input");


auto q_input = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 32, input_shape[2] + 1, input_shape[3]});
q_input->set_friendly_name("q_input");

auto scale_node = ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {1});

// TODO: add sdpa subgraph
std::shared_ptr<ov::Node> sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(
q_input,
k_input,
concat_or_reshape,
mask_input_1,
scale_node,
false);
std::shared_ptr<ov::Node> sdpa;
ov::ParameterVector params = {param, param2, mask_input, k_input, q_input};

// SDPA with sink (6 inputs) for GPT-OSS pattern
if (test.withSink) {
auto sink = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 32, 1, 1});
sink->set_friendly_name("sink");
params.push_back(sink);

sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(
q_input, k_input, concat_or_reshape, mask_input_1, scale_node, sink, false);
} else {
sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(
q_input, k_input, concat_or_reshape, mask_input_1, scale_node, false);
}

sdpa->set_friendly_name("sdpa");
auto result = std::make_shared<ov::op::v0::Result>(sdpa);

result->set_friendly_name("res");
return std::make_shared<ov::Model>(ov::ResultVector{result}, ov::ParameterVector{param, param2, mask_input, k_input, q_input});
return std::make_shared<ov::Model>(ov::ResultVector{result}, params);

} else {
auto param3 = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 32, 1, input_shape[2] + 1});
Expand Down Expand Up @@ -344,9 +355,11 @@ const std::vector<bool> withSDPA{true, false};

const std::vector<bool> withHpAttenMask{true, false};

const std::vector<bool> withSink{true, false};

const std::vector<NetworkKind> networkKind = {
// llama2 or llama3 type of concat, with convert layer or without
NetworkKind::llama2, NetworkKind::llama3
NetworkKind::MHA, // Multi-Head Attention
NetworkKind::GQA // Grouped Query Attention
};

INSTANTIATE_TEST_SUITE_P(smoke_Run_MatchAndTransposeVT,
Expand All @@ -357,6 +370,7 @@ const std::vector<NetworkKind> networkKind = {
::testing::ValuesIn(withBroadCast),
::testing::ValuesIn(withSDPA),
::testing::ValuesIn(withHpAttenMask),
::testing::ValuesIn(withSink),
::testing::ValuesIn(networkKind)),
TransposeVTTest::getTestCaseName);

Expand Down
Loading