@@ -1444,23 +1444,19 @@ inline void attn_softmax_kernel<ov::float16>(ov::float16* a,
14441444
14451445 ov::float16 sum = 0 .0f ;
14461446 if (sink != nullptr ) {
1447- max = max > static_cast <const ov::float16>(*sink) ? max : static_cast <const ov::float16>(*sink);
1447+ max = std::max (max, static_cast <const ov::float16>(*sink));
1448+ }
1449+ exp_reduce_sum_f32 (a, max, len, sum);
1450+ if (sink != nullptr ) {
1451+ sum += std::exp (*sink - max);
14481452 }
14491453 if (dst_precision == ov::element::f32 ) {
1450- exp_reduce_sum_f32 (a, max, len, sum);
1451- if (sink != nullptr ) {
1452- sum += std::exp (*sink - max);
1453- }
14541454 ov::float16 scalar = 1 .0f / sum;
14551455 multiply_scalar (a, static_cast <float *>(a_dst), scalar, len);
14561456 // apply causual mask to final result instead of attn_score
14571457 if (total_size > len)
14581458 memset (static_cast <float *>(a_dst) + len, 0 , sizeof (float ) * (total_size - len));
14591459 } else {
1460- exp_reduce_sum_f32 (a, max, len, sum);
1461- if (sink != nullptr ) {
1462- sum += std::exp (*sink - max);
1463- }
14641460 ov::float16 scalar = 1 .0f / sum;
14651461 multiply_scalar_f32 (a, static_cast <ov::float16*>(a_dst), scalar, len);
14661462 // apply causual mask to final result instead of attn_score
0 commit comments