Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1537,23 +1537,27 @@ inline void attn_softmax_kernel<ov::float16>(ov::float16* a,
}

ov::float16 sum = 0.0F;
if (sink != nullptr) {
max = std::max(max, static_cast<const ov::float16>(*sink));
}
# if defined(OPENVINO_ARCH_ARM64)
const float max_f = static_cast<float>(max);
if (std::isinf(max_f) && max_f > 0.0F) {
detail::handle_inf_logits(a, a_dst, dst_precision, len, total_size, sink);
return;
}
# endif
exp_reduce_sum_f32(a, max, len, sum);
if (sink != nullptr) {
sum += std::exp(*sink - max);
}
ov::float16 scalar = 1.0F / sum;
if (dst_precision == ov::element::f32) {
exp_reduce_sum_f32(a, max, len, sum);
ov::float16 scalar = 1.0F / sum;
multiply_scalar(a, static_cast<float*>(a_dst), scalar, len);
// apply causual mask to final result instead of attn_score
if (total_size > len)
memset(static_cast<float*>(a_dst) + len, 0, sizeof(float) * (total_size - len));
} else {
exp_reduce_sum_f32(a, max, len, sum);
ov::float16 scalar = 1.0F / sum;
multiply_scalar_f32(a, static_cast<ov::float16*>(a_dst), scalar, len);
// apply causual mask to final result instead of attn_score
if (total_size > len)
Expand All @@ -1562,4 +1566,4 @@ inline void attn_softmax_kernel<ov::float16>(ov::float16* a,
}
#endif

} // namespace ov::Extensions::Cpu::XARCH
} // namespace ov::Extensions::Cpu::XARCH
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,7 @@ struct MHAKernel<ScaledDotProductAttention::KT_ACL, T> {
PlainTensor& output_emb,
bool has_out_transpose,
bool auto_causal,
[[maybe_unused]] PlainTensor& sink_input,
PlainTensor& sink_input,
float d_scale = 0.0F) {
auto B = query.size(0);
auto H = query.size(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,6 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
const auto& pattern_map = m.get_pattern_value_map();
auto root = m.get_match_root();

// only support sink input on x86 platform currently.
#ifndef OPENVINO_ARCH_X86_64
if (pattern_map.count(atten_sink)) {
return false;
}
#endif
// Check concat axes equality first
const auto concat_k_node = ov::as_type_ptr<ov::op::v0::Concat>(pattern_map.at(concat_k).get_node_shared_ptr());
const auto concat_v_node = ov::as_type_ptr<ov::op::v0::Concat>(pattern_map.at(concat_v).get_node_shared_ptr());
Expand Down
Loading