Commit c0094c6

Update RopeFusion to support Qwen model after SDPA to PA conversion. (#28620)
Details: Update RoPEFusion for the Qwen model.
Tickets: [CVS-161067](https://jira.devtools.intel.com/browse/CVS-161067)
1 parent 6eed0fb · commit c0094c6

File tree: 3 files changed, +138 -17 lines

Diff for: src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp (+40 -15)
@@ -723,6 +723,7 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
     auto rotary_emb_cos = makePattern("[1,?,1,?]");  // [1,..4096,1,128]
     auto rotary_emb_sin = makePattern("[1,?,1,?]");  // [1,..4096,1,128]
     auto qkv_proj = makePattern("[?,?,?]");          // [?,?,12288]
+    auto position_ids = makePattern();

     auto head_cnt = ov::gen_pattern::Symbol("head_cnt");
     auto head_size = ov::gen_pattern::Symbol("head_size");
@@ -749,14 +750,19 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
     auto ScatterUpdate_463814 = makePattern<opset3::ScatterUpdate>({{0, 0}, {1}, Gather_377635 | neg_Multiply, {0}});
     auto slice_Slice_446 =
         makePattern<ov::opset8::Slice>({rotary_emb_cos, Gather_377635 | neg_Multiply, {INT_MAX}, {1}, {1}});
+
+    auto gather_cos_by_pos_ids = makePattern<opset8::Gather>({rotary_emb_cos, position_ids, 1}, {{"batch_dims", 0}});
+    auto reshape_cos_to_expected_layout =
+        makePattern<opset8::Reshape>({gather_cos_by_pos_ids, {-1, 1, 1, 128}}, {{"special_zero", false}});
+
     auto slice_StridedSlice_446 = GenStridedSlice(rotary_emb_cos,
                                                   ScatterUpdate_463814,
                                                   {0, INT_MAX},
                                                   {1, 1},
                                                   1);  // tensor_array<f32[1,..4096,1,128]>
-    auto mul_Multiply_552 =
-        makePattern<opset1::Multiply>({slice_Slice_543, slice_StridedSlice_446 | slice_Slice_446},
-                                      {{"auto_broadcast", "numpy"}});  // tensor_array<f32[?,?,32,128]>
+    auto mul_Multiply_552 = makePattern<opset1::Multiply>(
+        {slice_Slice_543, slice_StridedSlice_446 | slice_Slice_446 | reshape_cos_to_expected_layout},
+        {{"auto_broadcast", "numpy"}});  // tensor_array<f32[?,?,32,128]>

     auto reshape_opt1 = [&](std::shared_ptr<Node> input_BLHS) {
         auto ShapeOf_485814 = makePattern<opset1::ShapeOf>({input_BLHS}, {});
@@ -790,18 +796,28 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
         makePattern<opset1::Squeeze>({Multiply_567527, -2});  // tensor_array<f32[?,?,32,64]>
     auto ListUnpack_586_Squeeze =
         makePattern<opset1::Squeeze>({ListUnpack_586_Split->output(0), -2});  // tensor_array<f32[?,?,32,64]>
-    auto cat_Concat_593 = makePattern<opset1::Concat>({ListUnpack_586_Squeeze_0, ListUnpack_586_Squeeze},
-                                                      {{"axis", -1}});  // tensor_array<f32[?,?,32,128]>
+
+    auto ListUnpack_Squeeze_0_1 =
+        makePattern<opset1::Reshape>({Multiply_567527, {-1, 1, 32, 64}}, {{"special_zero", false}});
+    auto ListUnpack_Squeeze_1 =
+        makePattern<opset1::Reshape>({ListUnpack_586_Split->output(0), {-1, 1, 32, 64}}, {{"special_zero", false}});
+
+    auto cat_Concat_593 = makePattern<opset1::Concat>(
+        {ListUnpack_586_Squeeze_0 | ListUnpack_Squeeze_0_1, ListUnpack_586_Squeeze | ListUnpack_Squeeze_1},
+        {{"axis", -1}});  // tensor_array<f32[?,?,32,128]>
     auto slice_StridedSlice_470 = GenStridedSlice(rotary_emb_sin,
                                                   ScatterUpdate_463814,
                                                   {0, INT_MAX},
                                                   {1, 1},
                                                   1);  // tensor_array<f32[1,..4096,1,128]>
     auto slice_Slice_470 =
         makePattern<opset8::Slice>({rotary_emb_sin, Gather_377635 | neg_Multiply, {INT_MAX}, {1}, {1}});
-    auto mul_Multiply_594 =
-        makePattern<opset1::Multiply>({cat_Concat_593, slice_StridedSlice_470 | slice_Slice_470},
-                                      {{"auto_broadcast", "numpy"}});  // tensor_array<f32[?,?,32,128]>
+    auto gather_sin_by_pos_ids = makePattern<opset8::Gather>({rotary_emb_sin, position_ids, 1}, {{"batch_dims", 0}});
+    auto reshape_sin_to_expected_layout =
+        makePattern<opset8::Reshape>({gather_sin_by_pos_ids, {-1, 1, 1, 128}}, {{"special_zero", false}});
+    auto mul_Multiply_594 = makePattern<opset1::Multiply>(
+        {cat_Concat_593, slice_StridedSlice_470 | slice_Slice_470 | reshape_sin_to_expected_layout},
+        {{"auto_broadcast", "numpy"}});  // tensor_array<f32[?,?,32,128]>
     auto add_Add_597 = makePattern<opset1::Add>({mul_Multiply_552, mul_Multiply_594},
                                                 {{"auto_broadcast", "numpy"}});  // tensor_array<f32[?,?,32,128]>

@@ -858,16 +874,25 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
         new_args.push_back(pattern_map.at(rotary_emb_cos));
         new_args.push_back(pattern_map.at(rotary_emb_sin));

+        ov::NodeVector rt_from = {pattern_map.at(Multiply_567527).get_node_shared_ptr(),
+                                  pattern_map.at(cat_Concat_593).get_node_shared_ptr(),
+                                  pattern_map.at(mul_Multiply_594).get_node_shared_ptr(),
+                                  pattern_map.at(add_Add_597).get_node_shared_ptr()};
+
+        if (pattern_map.count(position_ids)) {
+            new_args.push_back(pattern_map.at(position_ids));
+            config.gather_position_arg_id = 3;
+            rt_from.push_back(pattern_map.at(ListUnpack_Squeeze_0_1).get_node_shared_ptr());
+            rt_from.push_back(pattern_map.at(ListUnpack_Squeeze_1).get_node_shared_ptr());
+        } else {
+            rt_from.push_back(pattern_map.at(ListUnpack_586_Squeeze_0).get_node_shared_ptr());
+            rt_from.push_back(pattern_map.at(ListUnpack_586_Squeeze).get_node_shared_ptr());
+        }
+
         auto old_node = root;
         auto new_node = std::make_shared<op::internal::RoPE>(new_args, config);
         new_node->set_friendly_name(old_node->get_friendly_name());
-        ov::copy_runtime_info({pattern_map.at(Multiply_567527).get_node_shared_ptr(),
-                               pattern_map.at(ListUnpack_586_Squeeze_0).get_node_shared_ptr(),
-                               pattern_map.at(ListUnpack_586_Squeeze).get_node_shared_ptr(),
-                               pattern_map.at(cat_Concat_593).get_node_shared_ptr(),
-                               pattern_map.at(mul_Multiply_594).get_node_shared_ptr(),
-                               pattern_map.at(add_Add_597).get_node_shared_ptr()},
-                              new_node);
+        ov::copy_runtime_info(rt_from, new_node);
         ov::replace_node(old_node, new_node);
         return true;
     };
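
With this change, the cos/sin tables can reach the element-wise Multiply either through the original Slice/StridedSlice branch or through a Gather indexed by position_ids followed by a Reshape, which is the shape the graph takes after the SDPA-to-PagedAttention conversion; when the position_ids branch is the one that matched, it is appended as the fourth RoPE argument and config.gather_position_arg_id is set to 3. The following is a minimal, self-contained sketch, not OpenVINO code: Table, slice_tail and gather_rows are made-up stand-ins that only contrast the two row-selection schemes. For contiguous positions at the tail of the table they pick the same rows, while the Gather form also covers explicit, possibly non-contiguous position ids.

```cpp
// Sketch (assumptions, not OpenVINO code): contrast the two cos/sin selection
// paths the updated pattern accepts: slicing the tail of the cos table versus
// gathering rows by explicit position ids (the axis-1 Gather in the graph).
#include <cstddef>
#include <iostream>
#include <vector>

using Table = std::vector<std::vector<float>>;  // [max_pos][rotary_dims]

// Original path: the last seq_len rows of the table.
Table slice_tail(const Table& cos, size_t seq_len) {
    return Table(cos.end() - seq_len, cos.end());
}

// New path: rows picked by explicit position ids.
Table gather_rows(const Table& cos, const std::vector<size_t>& position_ids) {
    Table out;
    for (auto p : position_ids) out.push_back(cos[p]);
    return out;
}

int main() {
    Table cos(8, std::vector<float>(4, 0.f));
    for (size_t p = 0; p < cos.size(); ++p) cos[p][0] = static_cast<float>(p);

    auto a = slice_tail(cos, 3);           // rows 5, 6, 7 (tail of the table)
    auto b = gather_rows(cos, {5, 6, 7});  // same rows, chosen explicitly
    std::cout << (a == b ? "same rows\n" : "different rows\n");
    return 0;
}
```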

Diff for: src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp (+80 -0)
@@ -1217,6 +1217,86 @@ TEST_F(TransformationTestsF, ConvertToROPE_chatGLM3_PagedAttention) {
     }
 }

+TEST_F(TransformationTestsF, ConvertToROPE_Qwen_PagedAttention) {
+    using namespace ov;
+
+    {
+        auto position_ids = std::make_shared<opset1::Parameter>(ov::element::i64, ov::PartialShape{-1, -1});
+        auto qkv = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1, 1, 3 * 4096});
+
+        auto qkv_proj = makeOP<opset1::VariadicSplit>({qkv, 2, {4096, 4096, -1}});
+
+        auto view_Reshape = makeOP<opset1::Reshape>({qkv_proj->output(0), {0, 0, 32, 128}}, {{"special_zero", true}});
+        auto slice_Slice_4 = makeOP<opset8::Slice>({view_Reshape, {0}, {128}, {1}, {3}});
+        auto slice_Slice = makeConst(element::f32, ov::Shape({1, 4096, 1, 128}), {1});
+
+        auto Convert_50535 = makeOP<opset1::Convert>({position_ids}, {{"destination_type", "i32"}});
+        auto Unsqueeze_23750 = makeOP<opset1::Reshape>({Convert_50535, {-1, 1}}, {{"special_zero", false}});
+
+        auto slice_Slice_1 = makeOP<opset8::Gather>({slice_Slice, Unsqueeze_23750, 1}, {{"batch_dims", 0}});
+        auto Reshape_27400 = makeOP<opset1::Reshape>({slice_Slice_1, {-1, 1, 1, 128}}, {{"special_zero", false}});
+
+        auto mul_Multiply = makeOP<opset1::Multiply>({slice_Slice_4, Reshape_27400}, {{"auto_broadcast", "numpy"}});
+        auto reshape_Reshape = makeOP<opset1::Reshape>({slice_Slice_4, {0, 0, 32, 2, 64}}, {{"special_zero", true}});
+        auto ListUnpack_Split = makeOP<opset1::Split>({reshape_Reshape, -2}, {{"num_splits", 2}});
+        auto Multiply_54136 =
+            makeOP<opset1::Multiply>({ListUnpack_Split->output(1), -1.000000f}, {{"auto_broadcast", "numpy"}});
+        auto ListUnpack_Squeeze_0 =
+            makeOP<opset1::Reshape>({Multiply_54136, {-1, 1, 32, 64}}, {{"special_zero", false}});
+        auto ListUnpack_Squeeze =
+            makeOP<opset1::Reshape>({ListUnpack_Split->output(0), {-1, 1, 32, 64}}, {{"special_zero", false}});
+        auto cat_Concat = makeOP<opset1::Concat>({ListUnpack_Squeeze_0, ListUnpack_Squeeze}, {{"axis", -1}});
+
+        auto slice_Slice_2 = makeConst(element::f32, ov::Shape({1, 4096, 1, 128}), {1});
+        auto slice_Slice_6 = makeOP<opset8::Gather>({slice_Slice_2, Unsqueeze_23750, 1}, {{"batch_dims", 0}});
+        auto Reshape_27408 = makeOP<opset1::Reshape>({slice_Slice_6, {-1, 1, 1, 128}}, {{"special_zero", false}});
+        auto mul_Multiply_1 = makeOP<opset1::Multiply>({cat_Concat, Reshape_27408}, {{"auto_broadcast", "numpy"}});
+        auto add_Add = makeOP<opset1::Add>({mul_Multiply, mul_Multiply_1}, {{"auto_broadcast", "numpy"}});
+
+        auto slice_Slice_10 = makeConst(element::f32, ov::Shape({1, 32767, 1, 1}), {1});
+        auto view_Reshape_1 = makeOP<opset1::Reshape>({qkv_proj->output(1), {0, 0, 32, 128}}, {{"special_zero", true}});
+        auto slice_Slice_11 = makeOP<opset8::Slice>({view_Reshape_1, {0}, {128}, {1}, {3}});
+        auto mul_Multiply_2 = makeOP<opset1::Multiply>({slice_Slice_11, Reshape_27400}, {{"auto_broadcast", "numpy"}});
+        auto reshape_Reshape_1 = makeOP<opset1::Reshape>({slice_Slice_11, {0, 0, 32, 2, 64}}, {{"special_zero", true}});
+        auto ListUnpack_Split_1 = makeOP<opset1::Split>({reshape_Reshape_1, -2}, {{"num_splits", 2}});
+        auto Multiply_54139 =
+            makeOP<opset1::Multiply>({ListUnpack_Split_1->output(1), -1.000000f}, {{"auto_broadcast", "numpy"}});
+        auto ListUnpack_Squeeze_0_1 =
+            makeOP<opset1::Reshape>({Multiply_54139, {-1, 1, 32, 64}}, {{"special_zero", false}});
+        auto ListUnpack_Squeeze_1 =
+            makeOP<opset1::Reshape>({ListUnpack_Split_1->output(0), {-1, 1, 32, 64}}, {{"special_zero", false}});
+        auto cat_Concat_2 = makeOP<opset1::Concat>({ListUnpack_Squeeze_0_1, ListUnpack_Squeeze_1}, {{"axis", -1}});
+        auto mul_Multiply_3 = makeOP<opset1::Multiply>({cat_Concat_2, Reshape_27408}, {{"auto_broadcast", "numpy"}});
+        auto add_Add_1 = makeOP<opset1::Add>({mul_Multiply_2, mul_Multiply_3}, {{"auto_broadcast", "numpy"}});
+        model = std::make_shared<ov::Model>(ov::NodeVector{add_Add_1}, ov::ParameterVector{position_ids, qkv});
+    }
+
+    manager.register_pass<ov::pass::RoPEFusion>(false);
+
+    {
+        auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1, 1, 4096 * 3});
+        auto rotary_emp_sin = makeConst(element::f32, ov::Shape({1, 4096, 1, 128}), {1});
+        auto rotary_emp_cos = makeConst(element::f32, ov::Shape({1, 4096, 1, 128}), {1});
+        auto position_ids = std::make_shared<opset1::Parameter>(ov::element::i64, ov::PartialShape{-1, -1});
+        auto Convert_50535 = makeOP<opset1::Convert>({position_ids}, {{"destination_type", "i32"}});
+        auto Unsqueeze_23750 = makeOP<opset1::Reshape>({Convert_50535, {-1, 1}}, {{"special_zero", false}});
+        auto rope = makeOP<ov::op::internal::RoPE>({input, rotary_emp_sin, rotary_emp_cos, Unsqueeze_23750},
+                                                   {{"config.slice_start", 4096},
+                                                    {"config.slice_stop", 8192},
+                                                    {"config.input_trans0213", false},
+                                                    {"config.output_trans0213", false},
+                                                    {"config.is_interleaved", false},
+                                                    {"config.rotary_ndims", 128},
+                                                    {"config.is_chatglm", false},
+                                                    {"config.support_2d_rope", false},
+                                                    {"config.is_qwen", true},
+                                                    {"config.head_cnt", 32},
+                                                    {"config.head_size", 128},
+                                                    {"config.gather_position_arg_id", 3}});
+        model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, position_ids});
+    }
+}
+
 TEST_F(TransformationTestsF, ConvertToROPE_GPTJ_PagedAttention) {
     disable_rt_info_check();
     const int batch = -1;

Diff for: src/plugins/intel_cpu/src/nodes/rope.cpp (+18 -2)
@@ -338,11 +338,16 @@ struct RoPE::RoPEExecutorQwen : public RoPE::Executor {
        ov::intel_cpu::PlainTensor t_cos(inputs[1]);   // [1, present-kv-length, 1, rotary_dims]
        ov::intel_cpu::PlainTensor t_sin(inputs[2]);   // [1, present-kv-length, 1, rotary_dims]
        ov::intel_cpu::PlainTensor t_dst(outputs[0]);  // [batch, length, head_cnt, head_size]>
+       ov::intel_cpu::PlainTensor gather;
+
        auto rotary_dims = t_cos.size(3);

        if (m_config.slice_stop - m_config.slice_start > 0) {
            t_src = t_src.slice(2, m_config.slice_start, m_config.slice_stop);
        }
+       if (m_config.gather_position_arg_id > 0) {
+           gather.reset(inputs[m_config.gather_position_arg_id]);
+       }

        auto batch_size = t_src.size(0);
        auto seq_len = t_src.size(1);
@@ -351,9 +356,20 @@ struct RoPE::RoPEExecutorQwen : public RoPE::Executor {
        auto present_kv_len = t_cos.size(1);

        parallel_for3d(batch_size, seq_len, head_cnt, [&](size_t b, size_t p, size_t h) {
+           size_t sincos_pos;
+           if (gather) {
+               if (gather.m_rank == 4) {
+                   sincos_pos = gather.at<int32_t>({b, h, p, 0}, true);
+               } else {
+                   sincos_pos = gather.at<int32_t>({b, p}, true);
+               }
+           } else {
+               sincos_pos = present_kv_len - seq_len + p;
+           }
+
            auto* src = t_src.ptr<T>(b, p, h * head_size);
-           auto* cos = &t_cos.at<float>({b, present_kv_len - seq_len + p, h, 0}, true);
-           auto* sin = &t_sin.at<float>({b, present_kv_len - seq_len + p, h, 0}, true);
+           auto* cos = &t_cos.at<float>({b, sincos_pos, h, 0}, true);
+           auto* sin = &t_sin.at<float>({b, sincos_pos, h, 0}, true);
            auto* dst = t_dst.ptr<T>(b, p, h);

            if (m_rotaryKernel) {
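
On the CPU plugin side, RoPEExecutorQwen now reads an optional gathered-positions input: when config.gather_position_arg_id is non-zero, the cos/sin row index for each token comes from that tensor, otherwise it falls back to the original present_kv_len - seq_len + p offset. Below is a minimal standalone sketch of that selection rule under the assumption of a flat int32 [batch, seq_len] buffer; Positions and sincos_position are made-up stand-ins for the real PlainTensor input and the inlined logic, and the rank-4 per-head layout handled by the actual code is omitted.

```cpp
// Sketch (assumptions, not plugin code): pick the cos/sin row either from an
// explicit position-ids buffer or from the original running offset.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

struct Positions {
    std::vector<int32_t> data;  // row-major [batch, seq_len]; empty when absent
    size_t seq_len;
    explicit operator bool() const { return !data.empty(); }
    int32_t at(size_t b, size_t p) const { return data[b * seq_len + p]; }
};

// Mirrors the sincos_pos selection added to the executor.
size_t sincos_position(const Positions& gather, size_t b, size_t p, size_t present_kv_len, size_t seq_len) {
    if (gather) {
        return static_cast<size_t>(gather.at(b, p));  // gathered by position_ids
    }
    return present_kv_len - seq_len + p;              // original fallback
}

int main() {
    Positions ids{{5, 6, 7}, 3};
    for (size_t p = 0; p < 3; ++p) {
        std::cout << "with ids: " << sincos_position(ids, 0, p, 8, 3)
                  << ", fallback: " << sincos_position(Positions{}, 0, p, 8, 3) << "\n";
    }
    return 0;
}
```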
