Skip to content

Commit c084a1d

Browse files
committed
copilot suggested changes
1 parent b8f0759 commit c084a1d

File tree

1 file changed

+5
-9
lines changed

1 file changed

+5
-9
lines changed

src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1444,23 +1444,19 @@ inline void attn_softmax_kernel<ov::float16>(ov::float16* a,
14441444

14451445
ov::float16 sum = 0.0f;
14461446
if (sink != nullptr) {
1447-
max = max > static_cast<const ov::float16>(*sink) ? max : static_cast<const ov::float16>(*sink);
1447+
max = std::max(max, static_cast<const ov::float16>(*sink));
1448+
}
1449+
exp_reduce_sum_f32(a, max, len, sum);
1450+
if (sink != nullptr) {
1451+
sum += std::exp(*sink - max);
14481452
}
14491453
if (dst_precision == ov::element::f32) {
1450-
exp_reduce_sum_f32(a, max, len, sum);
1451-
if (sink != nullptr) {
1452-
sum += std::exp(*sink - max);
1453-
}
14541454
ov::float16 scalar = 1.0f / sum;
14551455
multiply_scalar(a, static_cast<float*>(a_dst), scalar, len);
14561456
// apply causal mask to final result instead of attn_score
14571457
if (total_size > len)
14581458
memset(static_cast<float*>(a_dst) + len, 0, sizeof(float) * (total_size - len));
14591459
} else {
1460-
exp_reduce_sum_f32(a, max, len, sum);
1461-
if (sink != nullptr) {
1462-
sum += std::exp(*sink - max);
1463-
}
14641460
ov::float16 scalar = 1.0f / sum;
14651461
multiply_scalar_f32(a, static_cast<ov::float16*>(a_dst), scalar, len);
14661462
// apply causal mask to final result instead of attn_score

0 commit comments

Comments
 (0)