Skip to content

Commit 8dcdae9

Browse files
committed
Rebase with master
1 parent c084a1d commit 8dcdae9

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -578,14 +578,15 @@ inline void scale_add2_reduce_max(ov::float16* a,
578578

579579
if (has_alibi) {
580580
svfloat16_t v_lookup = svld1_f16(pg_f16, reinterpret_cast<const float16_t*>(alibi_lookup + i));
581-
v_a = svmla_f16_x(pg_f16, v_a, v_lookup, v_alibi_slope);
581+
v_a = svmla_f16_z(pg_f16, v_a, v_lookup, v_alibi_slope);
582582
}
583583

584584
if (has_attn_mask) {
585585
static_assert(std::is_same_v<T, float> || std::is_same_v<T, ov::float16>,
586586
"attn_mask must be float or float16 type.");
587587
if constexpr (std::is_same_v<T, float>) {
588588
svfloat16_t zero = svdup_n_f16(0.0f);
589+
size_t inc_low = (vec_len + 1) / 2;
589590
size_t inc_high = vec_len / 2;
590591
svbool_t pg_f32_low = svwhilelt_b32(0, static_cast<int>(inc_low));
591592
svbool_t pg_f32_high = svwhilelt_b32(0, static_cast<int>(inc_high));
@@ -607,9 +608,13 @@ inline void scale_add2_reduce_max(ov::float16* a,
607608
v_a = svuzp1(low_f16_out, high_f16_out);
608609
} else if constexpr (std::is_same_v<T, ov::float16>) {
609610
svfloat16_t v_mask = svld1_f16(pg_f16, reinterpret_cast<const float16_t*>(attn_mask + i));
610-
v_a = svadd_f16_z(pg_f16, v_a, v_mask);
611+
v_a = svadd_f16_x(pg_f16, v_a, v_mask);
611612
}
612613
}
614+
615+
if (has_causal_mask) {
616+
svuint8_t v_maski8 = svld1_u8(pg_u8, causal_mask + i);
617+
svuint16_t v_maski16 = svtrn1_u16(svreinterpret_u16_u8(v_maski8), svdup_n_u16(0));
613618
svbool_t kmask = svcmpeq_u16(pg_u16, v_maski16, v_zeroi16);
614619
kmask = sveor_z(pg_u16, kmask, mask_xor);
615620
v_a = svsel_f16(kmask, v_nfltmax, v_a);
@@ -1443,12 +1448,16 @@ inline void attn_softmax_kernel<ov::float16>(ov::float16* a,
14431448
}
14441449

14451450
ov::float16 sum = 0.0f;
1451+
ov::float16 clamped_sink_value = 0.0f;
14461452
if (sink != nullptr) {
1447-
max = std::max(max, static_cast<const ov::float16>(*sink));
1453+
clamped_sink_value = static_cast<const ov::float16>(*sink);
1454+
clamped_sink_value =
1455+
std::isinf(clamped_sink_value) ? std::numeric_limits<ov::float16>::max() : clamped_sink_value;
1456+
max = std::max(max, clamped_sink_value);
14481457
}
14491458
exp_reduce_sum_f32(a, max, len, sum);
14501459
if (sink != nullptr) {
1451-
sum += std::exp(*sink - max);
1460+
sum += std::exp(clamped_sink_value - max);
14521461
}
14531462
if (dst_precision == ov::element::f32) {
14541463
ov::float16 scalar = 1.0f / sum;

0 commit comments

Comments
 (0)