Skip to content

Commit 00a7592

Browse files
committed
fix PA sliding_window issue: add sliding_window process for second token
1 parent f3e8381 commit 00a7592

File tree

1 file changed

+81
-27
lines changed

1 file changed

+81
-27
lines changed

src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp

Lines changed: 81 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,6 +1220,7 @@ struct MHAHelper {
12201220
for (size_t pq = 0; pq < q_len; pq++) {
12211221
for (size_t h = hq_beg; h < hq_end; h++) {
12221222
// apply attention mask & sofmax
1223+
float* score = _weight.ptr<float>(ithr, h - hq_beg, pq);
12231224
float* alibi_lookup = nullptr;
12241225
float alibi_slope = 0.F;
12251226
if (alibi_slopes) {
@@ -1230,24 +1231,50 @@ struct MHAHelper {
12301231
if (sinks) {
12311232
sink = &sinks.at<float>({batch_in_seq, h, pq, 0}, true);
12321233
}
1233-
attn_softmax_kernel<float>(_weight.ptr<float>(ithr, h - hq_beg, pq),
1234-
_weight.ptr<float>(ithr, h - hq_beg, pq),
1235-
_d_scale,
1236-
alibi_lookup,
1237-
nullptr,
1238-
nullptr,
1239-
false,
1240-
cur_kv_len,
1241-
cur_kv_len,
1242-
ov::element::f32,
1243-
ov::element::f32,
1244-
sink,
1245-
alibi_slope);
1234+
if (_sliding_window) {
1235+
size_t start_idx = 0;
1236+
size_t new_causal = cur_kv_len;
1237+
float* sw_alibi_lookup = nullptr;
1238+
if (cur_kv_len > _sliding_window) {
1239+
start_idx = cur_kv_len - _sliding_window;
1240+
new_causal = _sliding_window;
1241+
}
1242+
attn_softmax_kernel<float>(score + start_idx,
1243+
score + start_idx,
1244+
_d_scale,
1245+
sw_alibi_lookup,
1246+
nullptr,
1247+
nullptr,
1248+
false,
1249+
new_causal,
1250+
cur_kv_len - start_idx,
1251+
ov::element::f32,
1252+
ov::element::f32,
1253+
sink,
1254+
alibi_slope);
1255+
if (start_idx > 0) {
1256+
memset(score, 0, sizeof(float) * start_idx);
1257+
}
1258+
} else {
1259+
attn_softmax_kernel<float>(score,
1260+
score,
1261+
_d_scale,
1262+
alibi_lookup,
1263+
nullptr,
1264+
nullptr,
1265+
false,
1266+
cur_kv_len,
1267+
cur_kv_len,
1268+
ov::element::f32,
1269+
ov::element::f32,
1270+
sink,
1271+
alibi_slope);
1272+
}
12461273
if (score_output) {
12471274
// aligned to cache line to avoid false sharing
12481275
static constexpr int cache_line_size = dnnl::impl::cpu::platform::get_cache_line_size();
12491276
std::memcpy(score_output + h * rnd_up(cur_kv_len, cache_line_size / sizeof(float)),
1250-
_weight.ptr<float>(ithr, h - hq_beg, pq),
1277+
score,
12511278
cur_kv_len * sizeof(float));
12521279
}
12531280
}
@@ -1403,6 +1430,7 @@ struct MHAHelper {
14031430
auto cur_kv_len = static_cast<size_t>(past_lens.ptr<int32_t>()[b]) + 1;
14041431
auto ncausal = cur_kv_len;
14051432
// apply attention mask & sofmax
1433+
float* score = _weight_bhl.ptr<float>(b, h, pq);
14061434
float* alibi_lookup = nullptr;
14071435
float alibi_slope = 0.F;
14081436
if (alibi_slopes) {
@@ -1413,19 +1441,45 @@ struct MHAHelper {
14131441
if (sinks) {
14141442
sink = &sinks.at<float>({b, h, pq, 0}, true);
14151443
}
1416-
attn_softmax_kernel<float>(_weight_bhl.ptr<float>(b, h, pq),
1417-
_weight_bhl.ptr<float>(b, h, pq),
1418-
_d_scale,
1419-
alibi_lookup,
1420-
nullptr,
1421-
nullptr,
1422-
false,
1423-
ncausal,
1424-
cur_kv_len,
1425-
ov::element::f32,
1426-
ov::element::f32,
1427-
sink,
1428-
alibi_slope);
1444+
if (_sliding_window) {
1445+
size_t start_idx = 0;
1446+
size_t new_causal = ncausal;
1447+
float* sw_alibi_lookup = nullptr;
1448+
if (ncausal > _sliding_window) {
1449+
start_idx = ncausal - _sliding_window;
1450+
new_causal = _sliding_window;
1451+
}
1452+
attn_softmax_kernel<float>(score + start_idx,
1453+
score + start_idx,
1454+
_d_scale,
1455+
sw_alibi_lookup,
1456+
nullptr,
1457+
nullptr,
1458+
false,
1459+
new_causal,
1460+
cur_kv_len - start_idx,
1461+
ov::element::f32,
1462+
ov::element::f32,
1463+
sink,
1464+
alibi_slope);
1465+
if (start_idx > 0) {
1466+
memset(score, 0, sizeof(float) * start_idx);
1467+
}
1468+
} else {
1469+
attn_softmax_kernel<float>(score,
1470+
score,
1471+
_d_scale,
1472+
alibi_lookup,
1473+
nullptr,
1474+
nullptr,
1475+
false,
1476+
ncausal,
1477+
cur_kv_len,
1478+
ov::element::f32,
1479+
ov::element::f32,
1480+
sink,
1481+
alibi_slope);
1482+
}
14291483
};
14301484

14311485
size_t h_dims = loop_hk ? Hk : H;

0 commit comments

Comments
 (0)