-
Notifications
You must be signed in to change notification settings - Fork 3.1k
[NPUW] gemma-2 patterns added to preserve tail constants matcher #32465
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 22 commits
6ad8b26
1209046
8c2da6f
29cb65e
1ea470e
3c19c32
be1e6d4
88c3e74
3125293
84cb7bc
5274f04
268048a
0a811c9
b872130
2ca0339
a9af6f9
fb12f9a
240e41b
a0186bd
d673f36
d52dc2e
8e04ee1
87029d7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -43,6 +43,23 @@ | |
|
|
||
| namespace opp = ov::pass::pattern; | ||
|
|
||
| // specific function that matches a subgraph appearing as a result of LPT transformations | ||
| auto match_down_up_convert_subgraph_after_lpt = [](const ov::Output<ov::Node>& input) { | ||
| auto upconvert = opp::wrap_type<ov::op::v0::Convert>({input}, opp::type_matches(ov::element::f32)); | ||
|
|
||
| auto upscale = opp::wrap_type<ov::op::v0::Constant>(opp::rank_equals(0)); | ||
| auto upmul = opp::wrap_type<ov::op::v1::Multiply>({upconvert, upscale}); | ||
|
|
||
| auto downscale = opp::wrap_type<ov::op::v0::Constant>(opp::rank_equals(0)); | ||
| auto downmul = opp::wrap_type<ov::op::v1::Multiply>({upmul, downscale}); | ||
|
|
||
| auto downconvert = | ||
| opp::wrap_type<ov::op::v0::Convert>({downmul}, | ||
| opp::type_matches_any({ov::element::f8e4m3, ov::element::f8e5m2})); | ||
|
|
||
| return downconvert; | ||
| }; | ||
|
|
||
| class RemoveEmptyKVTensors : public ov::pass::MatcherPass { | ||
| public: | ||
| OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::RemoveEmptyKVTensors"); | ||
|
|
@@ -54,7 +71,10 @@ class RemoveEmptyKVTensors : public ov::pass::MatcherPass { | |
|
|
||
| RemoveEmptyKVTensors(Context::Ref ctx) { | ||
| auto param = opp::wrap_type<ov::op::v0::Parameter>(); | ||
| auto concat = opp::wrap_type<ov::op::v0::Concat>({param, opp::any_input()}); | ||
| auto param_or = | ||
| std::make_shared<opp::op::Or>(ov::OutputVector{param, match_down_up_convert_subgraph_after_lpt(param)}); | ||
|
|
||
| auto concat = opp::wrap_type<ov::op::v0::Concat>({param_or, opp::any_input()}); | ||
|
|
||
| auto callback = [=](opp::Matcher& m) { | ||
| auto& node_to_output = m.get_pattern_value_map(); | ||
|
|
@@ -63,15 +83,28 @@ class RemoveEmptyKVTensors : public ov::pass::MatcherPass { | |
|
|
||
| ctx.get().old_params.push_back(matched_param); | ||
|
|
||
| auto users = matched_param->get_users(); | ||
| if (users.size() == 2u) { | ||
| auto shapeof_node = ov::is_type<ov::op::v3::ShapeOf>(users[0]) ? users[0] : users[1]; | ||
| NPUW_ASSERT(ov::is_type<ov::op::v3::ShapeOf>(shapeof_node)); | ||
| auto cst_node = | ||
| ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, matched_param->get_shape()); | ||
| ov::replace_node(shapeof_node, cst_node); | ||
| } else { | ||
| NPUW_ASSERT(users.size() == 1u); | ||
| // Use concat's first input source node to find ShapeOf users. | ||
| // This works universally for both the plain parameter and the down_up_convert subgraph cases, | ||
| // because in the subgraph case matched_param->get_users() would return the Convert | ||
| // node (first node of the subgraph), not the ShapeOf. | ||
| auto concat_input0_node = matched_node_concat->input(0).get_source_output().get_node_shared_ptr(); | ||
| auto users = concat_input0_node->get_users(); | ||
|
|
||
| // In the subgraph case the parameter itself may also have a ShapeOf user, | ||
| // so check both the concat input node and the parameter. | ||
| if (concat_input0_node != matched_param) { | ||
| auto param_users = matched_param->get_users(); | ||
| users.insert(users.end(), param_users.begin(), param_users.end()); | ||
| } | ||
|
|
||
| // Remove duplicates (concat itself will appear in users) | ||
|
||
| // Find and replace ShapeOf nodes with constants | ||
| for (auto& user : users) { | ||
| if (ov::is_type<ov::op::v3::ShapeOf>(user)) { | ||
| auto cst_node = | ||
| ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, matched_param->get_shape()); | ||
| ov::replace_node(user, cst_node); | ||
| } | ||
| } | ||
|
|
||
| // Redirect second concat input to every node which reads from concat | ||
|
|
@@ -323,22 +356,6 @@ class GroupQueryAttentionDecomposition : public ov::pass::MatcherPass { | |
| class RedirectNewKvToOutput : public ov::pass::MatcherPass { | ||
| public: | ||
| RedirectNewKvToOutput() { | ||
| auto match_down_up_convert_subgraph = [](const ov::Output<ov::Node>& input) { | ||
| auto upconvert = opp::wrap_type<ov::op::v0::Convert>({input}, opp::type_matches(ov::element::f32)); | ||
|
|
||
| auto upscale = opp::wrap_type<ov::op::v0::Constant>(opp::rank_equals(0)); | ||
| auto upmul = opp::wrap_type<ov::op::v1::Multiply>({upconvert, upscale}); | ||
|
|
||
| auto downscale = opp::wrap_type<ov::op::v0::Constant>(opp::rank_equals(0)); | ||
| auto downmul = opp::wrap_type<ov::op::v1::Multiply>({upmul, downscale}); | ||
|
|
||
| auto downconvert = | ||
| opp::wrap_type<ov::op::v0::Convert>({downmul}, | ||
| opp::type_matches_any({ov::element::f8e4m3, ov::element::f8e5m2})); | ||
|
|
||
| return downconvert; | ||
| }; | ||
|
|
||
| // example of fp8 inputs to concat | ||
| // input0 : float8e4m3[1,32,1151,96] | ||
| // input1 : float8e4m3[1,32,1,96] | ||
|
|
@@ -348,13 +365,13 @@ class RedirectNewKvToOutput : public ov::pass::MatcherPass { | |
| // TODO: this matcher logic would be better covered by unit tests | ||
| auto input0 = opp::wrap_type<ov::op::v0::Parameter>(); | ||
| auto input0_or = | ||
| std::make_shared<opp::op::Or>(ov::OutputVector{input0, match_down_up_convert_subgraph(input0)}); | ||
| std::make_shared<opp::op::Or>(ov::OutputVector{input0, match_down_up_convert_subgraph_after_lpt(input0)}); | ||
|
|
||
| auto input1 = opp::any_input(); | ||
|
|
||
| auto kv_concat = opp::wrap_type<ov::op::v0::Concat>({input0_or, input1}); | ||
| auto result1 = opp::wrap_type<ov::op::v0::Result>(kv_concat); | ||
| auto result2 = opp::wrap_type<ov::op::v0::Result>(match_down_up_convert_subgraph(kv_concat)); | ||
| auto result2 = opp::wrap_type<ov::op::v0::Result>(match_down_up_convert_subgraph_after_lpt(kv_concat)); | ||
|
|
||
| auto result_or = std::make_shared<opp::op::Or>(ov::OutputVector{result1, result2}); | ||
|
|
||
|
|
@@ -1162,6 +1179,7 @@ struct NPUDesc { | |
| std::string arch; | ||
| int64_t max_tiles = 0; | ||
| bool compiler_dq = false; | ||
| bool compiler_matmul_gate = false; | ||
| int64_t compiler_ver = 0; | ||
| bool support_flash_attention_tile = false; | ||
| }; | ||
|
|
@@ -1199,6 +1217,19 @@ std::optional<NPUDesc> extract_npu_descriptor(const std::shared_ptr<const ov::IP | |
| ov::AnyMap{{ov::intel_npu::compiler_type.name(), target_compiler_type}}) | ||
| .as<int64_t>(); | ||
| } | ||
| LOG_INFO("Compiler version: " << ONEAPI_VERSION_MAJOR(desc.compiler_ver) << "." | ||
| << ONEAPI_VERSION_MINOR(desc.compiler_ver)); | ||
|
|
||
| constexpr std::string_view compiler_gate_support_msg = | ||
| "Compiler: accurate gated matmul (MatMul -> Divide -> Tanh -> Multiply -> Result) : "; | ||
|
|
||
| if (desc.compiler_ver >= ONEAPI_MAKE_VERSION(7, 28)) { | ||
| // accuracy for gated matmul fixed at 7.28 | ||
| desc.compiler_matmul_gate = true; | ||
| LOG_INFO(compiler_gate_support_msg << "supported"); | ||
| } else { | ||
| LOG_WARN(compiler_gate_support_msg << "unsupported"); | ||
| } | ||
|
|
||
| if (desc.arch == "5010" && desc.compiler_ver >= ONEAPI_MAKE_VERSION(7, 29)) { | ||
| // Flash attention tile is supported starting from compiler version 7.29 on NPU5010 | ||
|
|
@@ -1247,6 +1278,13 @@ ov::AnyMap get_baseline_common_config(const std::optional<NPUDesc>& npudesc) { | |
| config.erase("NPUW_DCOFF_TYPE"); | ||
| config.erase("NPUW_DCOFF_SCALE"); | ||
| } | ||
|
|
||
| // default value is ON | ||
| // for compiler versions >= 7.28 value is ON | ||
| // for other compiler versions value is OFF | ||
| if (npudesc.has_value()) { | ||
| config.emplace("NPUW_MM_GATED", (npudesc->compiler_matmul_gate ? "YES" : "NO")); | ||
| } | ||
| return config; | ||
| } | ||
|
|
||
|
|
@@ -1877,7 +1915,6 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m | |
|
|
||
| if (!m_is_embedding) { | ||
| if (!m_use_chunk_prefill) { | ||
| // TODO: sometimes it is ok if we cannot find any empty inputs or not? | ||
| NPUW_ASSERT(remove_empty_kv_inputs(prefill_model)); | ||
| } else { | ||
| LOG_DEBUG("Don't remove input key/values from prefill model."); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great catch!!