-
Notifications
You must be signed in to change notification settings - Fork 3.1k
[NPUW] Replacing longrope pattern with precalculated values #33011
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2a01767
a1e9ab1
6862daa
bfbd69e
f2c3a80
1a2963b
3a23df4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,7 @@ | |
| #include "transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp" | ||
|
|
||
| namespace opp = ov::pass::pattern; | ||
| namespace pre_compute = ov::npuw::patterns::pre_compute; | ||
|
|
||
| namespace { | ||
| // TODO: copied from common tests | ||
|
|
@@ -49,6 +50,46 @@ static ov::OutputVector makeCosSinCache(const size_t max_position_embeddings, | |
|
|
||
| return {Cos, Sin}; | ||
| } | ||
|
|
||
| void replaceSinCosByCache(int max_prompt_len, const ov::OutputVector& cache, const pre_compute::RopePatternDesc* rpe) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Might be to align with |
||
| auto inv_freq_size = ov::shape_size(rpe->matched_inv_freq->get_shape()); | ||
|
|
||
| LOG_VERB("Making sin-cos cache of size: " << max_prompt_len << "x" << inv_freq_size); | ||
|
|
||
| // Step 1: axis constant (value 1), shared by the Gather and Squeeze ops below | ||
| auto axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}); | ||
|
|
||
| // Step 2: Gather from cos (cache[0]) and sin (cache[1]) using the matched position ids as indices | ||
| auto gather_cos = std::make_shared<ov::op::v8::Gather>(cache[0], rpe->matched_position_ids, axis); | ||
| auto gather_sin = std::make_shared<ov::op::v8::Gather>(cache[1], rpe->matched_position_ids, axis); | ||
| LOG_VERB("Created gather op facilitate LUT search: " << gather_cos->get_name() << ", " << gather_cos->get_shape()); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need verbose for the sin also? |
||
|
|
||
| // Step 3: convert gathered cos/sin to fp32 (cache is presumably fp16 - confirm in makeCosSinCache) | ||
| auto cos_fp32 = std::make_shared<ov::op::v0::Convert>(gather_cos, ov::element::f32); | ||
| auto sin_fp32 = std::make_shared<ov::op::v0::Convert>(gather_sin, ov::element::f32); | ||
|
|
||
| // Step 4: squeeze the converted cos/sin along the same axis constant used for the gather | ||
| auto squeeze_cos = std::make_shared<ov::op::v0::Squeeze>(cos_fp32, axis); | ||
| auto squeeze_sin = std::make_shared<ov::op::v0::Squeeze>(sin_fp32, axis); | ||
|
|
||
| LOG_VERB("Created squeeze_cos op to reduce axis=1: " << squeeze_cos->get_name() << ", " | ||
| << squeeze_cos->get_shape()); | ||
| LOG_VERB("Created squeeze_sin op to reduce axis=1: " << squeeze_sin->get_name() << ", " | ||
| << squeeze_sin->get_shape()); | ||
|
|
||
| LOG_VERB("Rope cos detected at: " << rpe->matched_cos->get_name() << ", replacing by cache node: " | ||
| << gather_cos->get_name() << ", " << gather_cos->get_shape()); | ||
| LOG_VERB("Rope sin detected at: " << rpe->matched_sin->get_name() << ", replacing by cache node: " | ||
| << gather_sin->get_name() << ", " << gather_sin->get_shape()); | ||
|
|
||
| // Step 5: replacing the matched cos and sin nodes with the squeezed cache lookups | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cos and sin |
||
| ov::replace_node(rpe->matched_cos, squeeze_cos); | ||
| ov::replace_node(rpe->matched_sin, squeeze_sin); | ||
|
|
||
| // disconnecting gather from rest of subgraph started from concat_1: detach the producer feeding input 0 of the matched concat | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Minor comment: typo in |
||
| auto gather_input_to_concat = rpe->matched_concat->input(0); | ||
| gather_input_to_concat.get_source_output().remove_target_input(gather_input_to_concat); | ||
| } | ||
| } // namespace | ||
|
|
||
| ov::npuw::patterns::pre_compute::RopePatternLLama2::RopePatternLLama2() : matcher("sin-cos-matcher") { | ||
|
|
@@ -87,54 +128,114 @@ ov::npuw::patterns::pre_compute::RopePatternLLama2::RopePatternLLama2() : matche | |
| matcher.register_patterns({output_sin, output_cos}, make_matcher_callback()); | ||
| } | ||
|
|
||
| ov::npuw::patterns::pre_compute::RopeCacheMatcher::RopeCacheMatcher(const uint32_t max_prompt_len, | ||
| const std::shared_ptr<ov::Model>& model) { | ||
| auto rpe = std::make_shared<RopePatternLLama2>(); | ||
| ov::npuw::patterns::pre_compute::LongRopePatternPhi::LongRopePatternPhi() : matcher("sin-cos-matcher") { | ||
| auto MakeConstant = []() { | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Run pass conditionally |
||
| return opp::wrap_type<ov::op::v0::Constant>(); | ||
| }; | ||
|
|
||
| rpe->transform_cb = [&]() { | ||
| auto inv_freq_size = ov::shape_size(rpe->matched_inv_freq->get_shape()); | ||
| auto make_select_pattern = [&](const std::shared_ptr<ov::Node>& position_ids, | ||
| const std::shared_ptr<ov::Node>& inv_freq_short, | ||
| const std::shared_ptr<ov::Node>& inv_freq_long) { | ||
| auto red_max = opp::wrap_type<ov::op::v1::ReduceMax>({position_ids, MakeConstant()}); | ||
| auto add = opp::wrap_type<ov::op::v1::Add>({red_max, MakeConstant()}); | ||
| // max(position_ids) + 1 <= original_max_position_embeddings | ||
| auto leq = opp::wrap_type<ov::op::v1::LessEqual>({add, MakeConstant()}); | ||
|
|
||
| LOG_VERB("Making sin-cos cache of size: " << max_prompt_len << "x" << inv_freq_size); | ||
| auto inv_freq_short_conv = opp::optional<ov::op::v0::Convert>({inv_freq_short->output(0)}); | ||
| auto inv_freq_long_conv = opp::optional<ov::op::v0::Convert>({inv_freq_long->output(0)}); | ||
|
|
||
| // shapes that matches max possible position | ||
| auto cache = makeCosSinCache(max_prompt_len, rpe->matched_inv_freq); | ||
| // max(position_ids) + 1 <= original_max_position_embeddings ? short_factor : long_factor; | ||
| auto select = opp::wrap_type<ov::op::v1::Select>({leq, inv_freq_short_conv, inv_freq_long_conv}); | ||
| auto unsqueeze = opp::optional<ov::op::v0::Unsqueeze>({select, MakeConstant()}); | ||
| auto unsqueeze_1 = opp::optional<ov::op::v0::Unsqueeze>({unsqueeze, MakeConstant()}); | ||
|
|
||
| return std::make_tuple(unsqueeze_1, leq, red_max); | ||
| }; | ||
|
|
||
| auto position_ids = opp::wrap_type<ov::op::v0::Parameter>(); | ||
|
|
||
| auto inv_freq_short = MakeConstant(); | ||
| auto inv_freq_long = MakeConstant(); | ||
|
|
||
| auto select_cond_max_pos_id = make_select_pattern(position_ids, inv_freq_short, inv_freq_long); | ||
| auto select = std::get<0>(select_cond_max_pos_id); | ||
| auto cond = std::get<1>(select_cond_max_pos_id); | ||
| auto max_pos_id = std::get<2>(select_cond_max_pos_id); | ||
|
|
||
| auto shape_of = opp::wrap_type<ov::op::v3::ShapeOf>({opp::any_input()}); | ||
| auto gather = opp::wrap_type<ov::op::v8::Gather>({shape_of, opp::any_input(), opp::any_input()}); | ||
| auto concat_1 = opp::wrap_type<ov::op::v0::Concat>({gather, opp::any_input(), opp::any_input()}); | ||
| // here we can see inverse frequencies as a parameter or constant depending on partitioner passes | ||
| auto broadcast = opp::wrap_type<ov::op::v3::Broadcast>({select, concat_1}); | ||
| auto unsqueeze = opp::wrap_type<ov::op::v0::Unsqueeze>({position_ids, MakeConstant()}); | ||
| auto convert = opp::wrap_type<ov::op::v0::Convert>({unsqueeze}); | ||
| auto matmul = opp::wrap_type<ov::op::v0::MatMul>({broadcast, convert}); | ||
| auto transpose = opp::wrap_type<ov::op::v1::Transpose>({matmul, opp::any_input()}); | ||
| auto concat_2 = opp::wrap_type<ov::op::v0::Concat>({transpose, opp::any_input()}); | ||
| auto output_sin = opp::wrap_type<ov::op::v0::Sin>({concat_2}); | ||
| auto output_cos = opp::wrap_type<ov::op::v0::Cos>({concat_2}); | ||
|
|
||
| // Step 1: Define axis (gather along axis 1) | ||
| auto axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}); | ||
| init_cb = [=](const auto& matches) { | ||
| const auto& map_sin = matches.at(output_sin)[0]; | ||
| const auto& map_cos = matches.at(output_cos)[0]; | ||
|
|
||
| // Step 2: Apply Gather for cos and sin | ||
| auto gather_cos = std::make_shared<ov::op::v8::Gather>(cache[0], rpe->matched_position_ids, axis); | ||
| auto gather_sin = std::make_shared<ov::op::v8::Gather>(cache[1], rpe->matched_position_ids, axis); | ||
| LOG_VERB("Created gather op facilitate LUT search: " << gather_cos->get_name() << ", " | ||
| << gather_cos->get_shape()); | ||
| this->matched_position_ids = map_sin.at(position_ids).get_node_shared_ptr(); | ||
| this->matched_concat = map_sin.at(concat_1).get_node_shared_ptr(); | ||
| this->matched_inv_freq = map_sin.at(inv_freq_short).get_node_shared_ptr(); | ||
| this->matched_inv_freq_long = map_sin.at(inv_freq_long).get_node_shared_ptr(); | ||
| this->matched_cond = map_sin.at(cond).get_node_shared_ptr(); | ||
| this->max_pos_id = map_sin.at(max_pos_id).get_node_shared_ptr(); | ||
|
|
||
| // Step 2: convert fp16->fp32 | ||
| auto cos_fp32 = std::make_shared<ov::op::v0::Convert>(gather_cos, ov::element::f32); | ||
| auto sin_fp32 = std::make_shared<ov::op::v0::Convert>(gather_sin, ov::element::f32); | ||
| this->matched_cos = map_cos.at(output_cos).get_node_shared_ptr(); | ||
| this->matched_sin = map_sin.at(output_sin).get_node_shared_ptr(); | ||
|
|
||
| // Create the squeeze operation required after gather | ||
| auto squeeze_cos = std::make_shared<ov::op::v0::Squeeze>(cos_fp32, axis); | ||
| auto squeeze_sin = std::make_shared<ov::op::v0::Squeeze>(sin_fp32, axis); | ||
| LOG_VERB("Rope found : sin=" << matched_sin->get_name() << ", cos=" << matched_cos->get_name()); | ||
|
|
||
| LOG_VERB("Created squeeze_cos op to reduce axis=1: " << squeeze_cos->get_name() << ", " | ||
| << squeeze_cos->get_shape()); | ||
| LOG_VERB("Created squeeze_sin op to reduce axis=1: " << squeeze_sin->get_name() << ", " | ||
| << squeeze_sin->get_shape()); | ||
| return true; | ||
| }; | ||
|
|
||
| LOG_VERB("Rope cos detected at: " << rpe->matched_cos->get_name() << ", replacing by cache node: " | ||
| << gather_cos->get_name() << ", " << gather_cos->get_shape()); | ||
| LOG_VERB("Rope sin detected at: " << rpe->matched_sin->get_name() << ", replacing by cache node: " | ||
| << gather_sin->get_name() << ", " << gather_sin->get_shape()); | ||
| matcher.register_patterns({output_sin, output_cos}, make_matcher_callback()); | ||
| } | ||
|
|
||
| // replacing sin with gather op | ||
| ov::replace_node(rpe->matched_cos, squeeze_cos); | ||
| ov::replace_node(rpe->matched_sin, squeeze_sin); | ||
| ov::npuw::patterns::pre_compute::RopeCacheMatcher::RopeCacheMatcher(const uint32_t max_prompt_len, | ||
| const std::shared_ptr<ov::Model>& model, | ||
| const std::string& longrope_input_name) { | ||
| auto rpe = std::make_shared<RopePatternLLama2>(); | ||
|
|
||
| // disconnecting gather from rest or subgraph started from concat_1 | ||
| auto gather_input_to_concat = rpe->matched_concat->input(0); | ||
| gather_input_to_concat.get_source_output().remove_target_input(gather_input_to_concat); | ||
| rpe->transform_cb = [&]() { | ||
| auto cache = makeCosSinCache(max_prompt_len, rpe->matched_inv_freq); | ||
| replaceSinCosByCache(max_prompt_len, cache, rpe.get()); | ||
| }; | ||
| rpe->run_on_model(model); | ||
|
|
||
| auto long_rpe = std::make_shared<LongRopePatternPhi>(); | ||
|
|
||
| std::shared_ptr<ov::op::v0::Parameter> max_pos_id_param; | ||
| long_rpe->transform_cb = [&]() { | ||
| auto cache_short = makeCosSinCache(max_prompt_len, long_rpe->matched_inv_freq); | ||
| auto cache_long = makeCosSinCache(max_prompt_len, long_rpe->matched_inv_freq_long); | ||
|
|
||
| auto select_cos = std::make_shared<ov::op::v1::Select>(long_rpe->matched_cond, cache_short[0], cache_long[0]); | ||
| auto select_sin = std::make_shared<ov::op::v1::Select>(long_rpe->matched_cond, cache_short[1], cache_long[1]); | ||
|
|
||
| replaceSinCosByCache(max_prompt_len, {select_cos, select_sin}, long_rpe.get()); | ||
|
|
||
| auto max_pos_id_out = long_rpe->max_pos_id->output(0); | ||
| max_pos_id_param.reset(new ov::op::v0::Parameter(max_pos_id_out.get_element_type(), {1})); | ||
| max_pos_id_param->set_friendly_name(longrope_input_name); | ||
| max_pos_id_out.replace(max_pos_id_param->output(0)); | ||
| }; | ||
| long_rpe->run_on_model(model); | ||
|
|
||
| if (max_pos_id_param) { | ||
| model->add_parameters({max_pos_id_param}); | ||
| for (auto&& input : model->inputs()) { | ||
| if (input.get_node() == max_pos_id_param.get()) { | ||
| input.set_names({max_pos_id_param->get_friendly_name()}); | ||
| } | ||
| } | ||
| } | ||
| model->validate_nodes_and_infer_types(); | ||
| } | ||
|
|
||
| ov::npuw::patterns::pre_compute::RopeInverseFreq::RopeInverseFreq( | ||
|
|
@@ -154,6 +255,6 @@ ov::npuw::patterns::pre_compute::RopeInverseFreq::RopeInverseFreq( | |
| } | ||
|
|
||
| bool ov::npuw::patterns::pre_compute::RopeCache::run_on_model(const std::shared_ptr<ov::Model>& model) { | ||
| ov::npuw::patterns::pre_compute::RopeCacheMatcher ropeCache(m_max_prompt_len, model); | ||
| ov::npuw::patterns::pre_compute::RopeCacheMatcher ropeCache(m_max_prompt_len, model, m_longrope_input_name); | ||
| return true; | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Check if attention mask is right padded