@@ -1220,6 +1220,7 @@ struct MHAHelper {
12201220 for (size_t pq = 0 ; pq < q_len; pq++) {
12211221 for (size_t h = hq_beg; h < hq_end; h++) {
12221222 // apply attention mask & softmax
1223+ float * score = _weight.ptr <float >(ithr, h - hq_beg, pq);
12231224 float * alibi_lookup = nullptr ;
12241225 float alibi_slope = 0 .F ;
12251226 if (alibi_slopes) {
@@ -1230,24 +1231,50 @@ struct MHAHelper {
12301231 if (sinks) {
12311232 sink = &sinks.at <float >({batch_in_seq, h, pq, 0 }, true );
12321233 }
1233- attn_softmax_kernel<float >(_weight.ptr <float >(ithr, h - hq_beg, pq),
1234- _weight.ptr <float >(ithr, h - hq_beg, pq),
1235- _d_scale,
1236- alibi_lookup,
1237- nullptr ,
1238- nullptr ,
1239- false ,
1240- cur_kv_len,
1241- cur_kv_len,
1242- ov::element::f32 ,
1243- ov::element::f32 ,
1244- sink,
1245- alibi_slope);
1234+ if (_sliding_window) {
1235+ size_t start_idx = 0 ;
1236+ size_t new_causal = cur_kv_len;
1237+ float * sw_alibi_lookup = nullptr ;
1238+ if (cur_kv_len > _sliding_window) {
1239+ start_idx = cur_kv_len - _sliding_window;
1240+ new_causal = _sliding_window;
1241+ }
1242+ attn_softmax_kernel<float >(score + start_idx,
1243+ score + start_idx,
1244+ _d_scale,
1245+ sw_alibi_lookup,
1246+ nullptr ,
1247+ nullptr ,
1248+ false ,
1249+ new_causal,
1250+ cur_kv_len - start_idx,
1251+ ov::element::f32 ,
1252+ ov::element::f32 ,
1253+ sink,
1254+ alibi_slope);
1255+ if (start_idx > 0 ) {
1256+ memset (score, 0 , sizeof (float ) * start_idx);
1257+ }
1258+ } else {
1259+ attn_softmax_kernel<float >(score,
1260+ score,
1261+ _d_scale,
1262+ alibi_lookup,
1263+ nullptr ,
1264+ nullptr ,
1265+ false ,
1266+ cur_kv_len,
1267+ cur_kv_len,
1268+ ov::element::f32 ,
1269+ ov::element::f32 ,
1270+ sink,
1271+ alibi_slope);
1272+ }
12461273 if (score_output) {
12471274 // aligned to cache line to avoid false sharing
12481275 static constexpr int cache_line_size = dnnl::impl::cpu::platform::get_cache_line_size ();
12491276 std::memcpy (score_output + h * rnd_up (cur_kv_len, cache_line_size / sizeof (float )),
1250- _weight. ptr < float >(ithr, h - hq_beg, pq) ,
1277+ score ,
12511278 cur_kv_len * sizeof (float ));
12521279 }
12531280 }
@@ -1403,6 +1430,7 @@ struct MHAHelper {
14031430 auto cur_kv_len = static_cast <size_t >(past_lens.ptr <int32_t >()[b]) + 1 ;
14041431 auto ncausal = cur_kv_len;
14051432 // apply attention mask & softmax
1433+ float * score = _weight_bhl.ptr <float >(b, h, pq);
14061434 float * alibi_lookup = nullptr ;
14071435 float alibi_slope = 0 .F ;
14081436 if (alibi_slopes) {
@@ -1413,19 +1441,45 @@ struct MHAHelper {
14131441 if (sinks) {
14141442 sink = &sinks.at <float >({b, h, pq, 0 }, true );
14151443 }
1416- attn_softmax_kernel<float >(_weight_bhl.ptr <float >(b, h, pq),
1417- _weight_bhl.ptr <float >(b, h, pq),
1418- _d_scale,
1419- alibi_lookup,
1420- nullptr ,
1421- nullptr ,
1422- false ,
1423- ncausal,
1424- cur_kv_len,
1425- ov::element::f32 ,
1426- ov::element::f32 ,
1427- sink,
1428- alibi_slope);
1444+ if (_sliding_window) {
1445+ size_t start_idx = 0 ;
1446+ size_t new_causal = ncausal;
1447+ float * sw_alibi_lookup = nullptr ;
1448+ if (ncausal > _sliding_window) {
1449+ start_idx = ncausal - _sliding_window;
1450+ new_causal = _sliding_window;
1451+ }
1452+ attn_softmax_kernel<float >(score + start_idx,
1453+ score + start_idx,
1454+ _d_scale,
1455+ sw_alibi_lookup,
1456+ nullptr ,
1457+ nullptr ,
1458+ false ,
1459+ new_causal,
1460+ cur_kv_len - start_idx,
1461+ ov::element::f32 ,
1462+ ov::element::f32 ,
1463+ sink,
1464+ alibi_slope);
1465+ if (start_idx > 0 ) {
1466+ memset (score, 0 , sizeof (float ) * start_idx);
1467+ }
1468+ } else {
1469+ attn_softmax_kernel<float >(score,
1470+ score,
1471+ _d_scale,
1472+ alibi_lookup,
1473+ nullptr ,
1474+ nullptr ,
1475+ false ,
1476+ ncausal,
1477+ cur_kv_len,
1478+ ov::element::f32 ,
1479+ ov::element::f32 ,
1480+ sink,
1481+ alibi_slope);
1482+ }
14291483 };
14301484
14311485 size_t h_dims = loop_hk ? Hk : H;
0 commit comments