@@ -578,14 +578,15 @@ inline void scale_add2_reduce_max(ov::float16* a,
578578
579579 if (has_alibi) {
580580 svfloat16_t v_lookup = svld1_f16 (pg_f16, reinterpret_cast <const float16_t *>(alibi_lookup + i));
581- v_a = svmla_f16_x (pg_f16, v_a, v_lookup, v_alibi_slope);
581+ v_a = svmla_f16_z (pg_f16, v_a, v_lookup, v_alibi_slope);
582582 }
583583
584584 if (has_attn_mask) {
585585 static_assert (std::is_same_v<T, float > || std::is_same_v<T, ov::float16>,
586586 " attn_mask must be float or float16 type." );
587587 if constexpr (std::is_same_v<T, float >) {
588588 svfloat16_t zero = svdup_n_f16 (0 .0f );
589+ size_t inc_low = (vec_len + 1 ) / 2 ;
589590 size_t inc_high = vec_len / 2 ;
590591 svbool_t pg_f32_low = svwhilelt_b32 (0 , static_cast <int >(inc_low));
591592 svbool_t pg_f32_high = svwhilelt_b32 (0 , static_cast <int >(inc_high));
@@ -607,9 +608,13 @@ inline void scale_add2_reduce_max(ov::float16* a,
607608 v_a = svuzp1 (low_f16_out, high_f16_out);
608609 } else if constexpr (std::is_same_v<T, ov::float16>) {
609610 svfloat16_t v_mask = svld1_f16 (pg_f16, reinterpret_cast <const float16_t *>(attn_mask + i));
610- v_a = svadd_f16_z (pg_f16, v_a, v_mask);
611+ v_a = svadd_f16_x (pg_f16, v_a, v_mask);
611612 }
612613 }
614+
615+ if (has_causal_mask) {
616+ svuint8_t v_maski8 = svld1_u8 (pg_u8, causal_mask + i);
617+ svuint16_t v_maski16 = svtrn1_u16 (svreinterpret_u16_u8 (v_maski8), svdup_n_u16 (0 ));
613618 svbool_t kmask = svcmpeq_u16 (pg_u16, v_maski16, v_zeroi16);
614619 kmask = sveor_z (pg_u16, kmask, mask_xor);
615620 v_a = svsel_f16 (kmask, v_nfltmax, v_a);
@@ -1443,12 +1448,16 @@ inline void attn_softmax_kernel<ov::float16>(ov::float16* a,
14431448 }
14441449
14451450 ov::float16 sum = 0 .0f ;
1451+ ov::float16 clamped_sink_value = 0 .0f ;
14461452 if (sink != nullptr ) {
1447- max = std::max (max, static_cast <const ov::float16>(*sink));
1453+ clamped_sink_value = static_cast <const ov::float16>(*sink);
1454+ clamped_sink_value =
1455+ std::isinf (clamped_sink_value) ? std::numeric_limits<ov::float16>::max () : clamped_sink_value;
1456+ max = std::max (max, clamped_sink_value);
14481457 }
14491458 exp_reduce_sum_f32 (a, max, len, sum);
14501459 if (sink != nullptr ) {
1451- sum += std::exp (*sink - max);
1460+ sum += std::exp (clamped_sink_value - max);
14521461 }
14531462 if (dst_precision == ov::element::f32 ) {
14541463 ov::float16 scalar = 1 .0f / sum;
0 commit comments