@@ -1217,6 +1217,86 @@ TEST_F(TransformationTestsF, ConvertToROPE_chatGLM3_PagedAttention) {
1217
1217
}
1218
1218
}
1219
1219
1220
+ TEST_F (TransformationTestsF, ConvertToROPE_Qwen_PagedAttention) {
1221
+ using namespace ov ;
1222
+
1223
+ {
1224
+ auto position_ids = std::make_shared<opset1::Parameter>(ov::element::i64, ov::PartialShape{-1 , -1 });
1225
+ auto qkv = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1 , 1 , 3 * 4096 });
1226
+
1227
+ auto qkv_proj = makeOP<opset1::VariadicSplit>({qkv, 2 , {4096 , 4096 , -1 }});
1228
+
1229
+ auto view_Reshape = makeOP<opset1::Reshape>({qkv_proj->output (0 ), {0 , 0 , 32 , 128 }}, {{" special_zero" , true }});
1230
+ auto slice_Slice_4 = makeOP<opset8::Slice>({view_Reshape, {0 }, {128 }, {1 }, {3 }});
1231
+ auto slice_Slice = makeConst (element::f32, ov::Shape ({1 , 4096 , 1 , 128 }), {1 });
1232
+
1233
+ auto Convert_50535 = makeOP<opset1::Convert>({position_ids}, {{" destination_type" , " i32" }});
1234
+ auto Unsqueeze_23750 = makeOP<opset1::Reshape>({Convert_50535, {-1 , 1 }}, {{" special_zero" , false }});
1235
+
1236
+ auto slice_Slice_1 = makeOP<opset8::Gather>({slice_Slice, Unsqueeze_23750, 1 }, {{" batch_dims" , 0 }});
1237
+ auto Reshape_27400 = makeOP<opset1::Reshape>({slice_Slice_1, {-1 , 1 , 1 , 128 }}, {{" special_zero" , false }});
1238
+
1239
+ auto mul_Multiply = makeOP<opset1::Multiply>({slice_Slice_4, Reshape_27400}, {{" auto_broadcast" , " numpy" }});
1240
+ auto reshape_Reshape = makeOP<opset1::Reshape>({slice_Slice_4, {0 , 0 , 32 , 2 , 64 }}, {{" special_zero" , true }});
1241
+ auto ListUnpack_Split = makeOP<opset1::Split>({reshape_Reshape, -2 }, {{" num_splits" , 2 }});
1242
+ auto Multiply_54136 =
1243
+ makeOP<opset1::Multiply>({ListUnpack_Split->output (1 ), -1 .000000f }, {{" auto_broadcast" , " numpy" }});
1244
+ auto ListUnpack_Squeeze_0 =
1245
+ makeOP<opset1::Reshape>({Multiply_54136, {-1 , 1 , 32 , 64 }}, {{" special_zero" , false }});
1246
+ auto ListUnpack_Squeeze =
1247
+ makeOP<opset1::Reshape>({ListUnpack_Split->output (0 ), {-1 , 1 , 32 , 64 }}, {{" special_zero" , false }});
1248
+ auto cat_Concat = makeOP<opset1::Concat>({ListUnpack_Squeeze_0, ListUnpack_Squeeze}, {{" axis" , -1 }});
1249
+
1250
+ auto slice_Slice_2 = makeConst (element::f32, ov::Shape ({1 , 4096 , 1 , 128 }), {1 });
1251
+ auto slice_Slice_6 = makeOP<opset8::Gather>({slice_Slice_2, Unsqueeze_23750, 1 }, {{" batch_dims" , 0 }});
1252
+ auto Reshape_27408 = makeOP<opset1::Reshape>({slice_Slice_6, {-1 , 1 , 1 , 128 }}, {{" special_zero" , false }});
1253
+ auto mul_Multiply_1 = makeOP<opset1::Multiply>({cat_Concat, Reshape_27408}, {{" auto_broadcast" , " numpy" }});
1254
+ auto add_Add = makeOP<opset1::Add>({mul_Multiply, mul_Multiply_1}, {{" auto_broadcast" , " numpy" }});
1255
+
1256
+ auto slice_Slice_10 = makeConst (element::f32, ov::Shape ({1 , 32767 , 1 , 1 }), {1 });
1257
+ auto view_Reshape_1 = makeOP<opset1::Reshape>({qkv_proj->output (1 ), {0 , 0 , 32 , 128 }}, {{" special_zero" , true }});
1258
+ auto slice_Slice_11 = makeOP<opset8::Slice>({view_Reshape_1, {0 }, {128 }, {1 }, {3 }});
1259
+ auto mul_Multiply_2 = makeOP<opset1::Multiply>({slice_Slice_11, Reshape_27400}, {{" auto_broadcast" , " numpy" }});
1260
+ auto reshape_Reshape_1 = makeOP<opset1::Reshape>({slice_Slice_11, {0 , 0 , 32 , 2 , 64 }}, {{" special_zero" , true }});
1261
+ auto ListUnpack_Split_1 = makeOP<opset1::Split>({reshape_Reshape_1, -2 }, {{" num_splits" , 2 }});
1262
+ auto Multiply_54139 =
1263
+ makeOP<opset1::Multiply>({ListUnpack_Split_1->output (1 ), -1 .000000f }, {{" auto_broadcast" , " numpy" }});
1264
+ auto ListUnpack_Squeeze_0_1 =
1265
+ makeOP<opset1::Reshape>({Multiply_54139, {-1 , 1 , 32 , 64 }}, {{" special_zero" , false }});
1266
+ auto ListUnpack_Squeeze_1 =
1267
+ makeOP<opset1::Reshape>({ListUnpack_Split_1->output (0 ), {-1 , 1 , 32 , 64 }}, {{" special_zero" , false }});
1268
+ auto cat_Concat_2 = makeOP<opset1::Concat>({ListUnpack_Squeeze_0_1, ListUnpack_Squeeze_1}, {{" axis" , -1 }});
1269
+ auto mul_Multiply_3 = makeOP<opset1::Multiply>({cat_Concat_2, Reshape_27408}, {{" auto_broadcast" , " numpy" }});
1270
+ auto add_Add_1 = makeOP<opset1::Add>({mul_Multiply_2, mul_Multiply_3}, {{" auto_broadcast" , " numpy" }});
1271
+ model = std::make_shared<ov::Model>(ov::NodeVector{add_Add_1}, ov::ParameterVector{position_ids, qkv});
1272
+ }
1273
+
1274
+ manager.register_pass <ov::pass::RoPEFusion>(false );
1275
+
1276
+ {
1277
+ auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1 , 1 , 4096 * 3 });
1278
+ auto rotary_emp_sin = makeConst (element::f32, ov::Shape ({1 , 4096 , 1 , 128 }), {1 });
1279
+ auto rotary_emp_cos = makeConst (element::f32, ov::Shape ({1 , 4096 , 1 , 128 }), {1 });
1280
+ auto position_ids = std::make_shared<opset1::Parameter>(ov::element::i64, ov::PartialShape{-1 , -1 });
1281
+ auto Convert_50535 = makeOP<opset1::Convert>({position_ids}, {{" destination_type" , " i32" }});
1282
+ auto Unsqueeze_23750 = makeOP<opset1::Reshape>({Convert_50535, {-1 , 1 }}, {{" special_zero" , false }});
1283
+ auto rope = makeOP<ov::op::internal::RoPE>({input, rotary_emp_sin, rotary_emp_cos, Unsqueeze_23750},
1284
+ {{" config.slice_start" , 4096 },
1285
+ {" config.slice_stop" , 8192 },
1286
+ {" config.input_trans0213" , false },
1287
+ {" config.output_trans0213" , false },
1288
+ {" config.is_interleaved" , false },
1289
+ {" config.rotary_ndims" , 128 },
1290
+ {" config.is_chatglm" , false },
1291
+ {" config.support_2d_rope" , false },
1292
+ {" config.is_qwen" , true },
1293
+ {" config.head_cnt" , 32 },
1294
+ {" config.head_size" , 128 },
1295
+ {" config.gather_position_arg_id" , 3 }});
1296
+ model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, position_ids});
1297
+ }
1298
+ }
1299
+
1220
1300
TEST_F (TransformationTestsF, ConvertToROPE_GPTJ_PagedAttention) {
1221
1301
disable_rt_info_check ();
1222
1302
const int batch = -1 ;
0 commit comments