Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,9 @@ inline void scale_add2_reduce_max(float* a,
if (has_sparse_mask) { \
size_t mask_idx = (i + n * vec_len_f32_avx512) / sparse_block_size; \
uint8_t mask_val = sparse_mask[mask_idx]; \
__m512 v_mask_block = _mm512_set1_ps(mask_val ? 0.f : -FLT_MAX); \
v_a = _mm512_add_ps(v_a, v_mask_block); \
if (!mask_val) { \
v_a = v_nfltmax; \
} \
} \
if (has_causal_mask) { \
auto v_maski8 = \
Expand Down Expand Up @@ -355,8 +356,9 @@ inline void scale_add2_reduce_max(float* a,
if (has_sparse_mask) {
size_t mask_idx = i / sparse_block_size;
uint8_t mask_val = sparse_mask[mask_idx];
__m512 v_mask_block = _mm512_set1_ps(mask_val ? 0.f : -FLT_MAX);
v_a = _mm512_add_ps(v_a, v_mask_block);
if (!mask_val) {
v_a = v_nfltmax;
}
}

if (has_causal_mask) {
Expand Down Expand Up @@ -390,8 +392,9 @@ inline void scale_add2_reduce_max(float* a,
if (has_sparse_mask) {
size_t mask_idx = i / sparse_block_size;
uint8_t mask_val = sparse_mask[mask_idx];
__m512 v_mask_block = _mm512_set1_ps(mask_val ? 0.f : -FLT_MAX);
v_a = _mm512_add_ps(v_a, v_mask_block);
if (!mask_val) {
v_a = v_nfltmax;
}
}

if (has_causal_mask) {
Expand Down Expand Up @@ -439,8 +442,9 @@ inline void scale_add2_reduce_max(float* a,
if (has_sparse_mask) { \
size_t mask_idx = (i + n * vec_len_f32_avx2) / sparse_block_size; \
uint8_t mask_val = sparse_mask[mask_idx]; \
__m256 v_mask_block = _mm256_set1_ps(mask_val ? 0.f : -FLT_MAX); \
v_a = _mm256_add_ps(v_a, v_mask_block); \
if (!mask_val) { \
v_a = v_nfltmax; \
} \
} \
if (has_causal_mask) { \
auto v_maski8 = _mm_loadu_si128(reinterpret_cast<__m128i const*>(causal_mask + i + n * vec_len_f32_avx2)); \
Expand Down Expand Up @@ -476,8 +480,9 @@ inline void scale_add2_reduce_max(float* a,
if (has_sparse_mask) {
size_t mask_idx = i / sparse_block_size;
uint8_t mask_val = sparse_mask[mask_idx];
__m256 v_mask_block = _mm256_set1_ps(mask_val ? 0.f : -FLT_MAX);
v_a = _mm256_add_ps(v_a, v_mask_block);
if (!mask_val) {
v_a = v_nfltmax;
}
}

if (has_causal_mask) {
Expand Down Expand Up @@ -512,8 +517,9 @@ inline void scale_add2_reduce_max(float* a,
if (has_sparse_mask) {
size_t mask_idx = i / sparse_block_size;
uint8_t mask_val = sparse_mask[mask_idx];
__m256 v_mask_block = _mm256_set1_ps(mask_val ? 0.f : -FLT_MAX);
v_a = _mm256_add_ps(v_a, v_mask_block);
if (!mask_val) {
v_a = v_nfltmax;
}
}

if (has_causal_mask) {
Expand Down Expand Up @@ -560,8 +566,9 @@ inline void scale_add2_reduce_max(float* a,
if (has_sparse_mask) {
size_t mask_idx = i / sparse_block_size;
uint8_t mask_val = sparse_mask[mask_idx];
float32x4_t v_mask_block = vdupq_n_f32(mask_val ? 0.0F : -FLT_MAX);
v_a = vaddq_f32(v_a, v_mask_block);
if (!mask_val) {
v_a = v_nfltmax;
}
}

if (has_causal_mask) {
Expand Down Expand Up @@ -596,7 +603,9 @@ inline void scale_add2_reduce_max(float* a,
if (has_sparse_mask) {
size_t mask_idx = i / sparse_block_size;
uint8_t mask_val = sparse_mask[mask_idx];
a[i] += (mask_val ? 0.0F : -FLT_MAX);
if (!mask_val) {
a[i] = -FLT_MAX;
}
}

if (has_causal_mask) {
Expand Down
87 changes: 87 additions & 0 deletions src/plugins/intel_cpu/tests/unit/softmax_kernel_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Copyright (C) 2018-2026 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "nodes/kernels/scaled_attn/softmax_kernel.hpp"

#include <cmath>
#include <vector>

#include "gtest/gtest.h"

namespace {
TEST(SoftmaxKernelTest, AttnSoftmaxKernelWithSparseMask) {
std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
std::vector<float> output(input.size(), 0.0f);
std::vector<uint8_t> sparse_mask = {1, 0, 1, 0}; // Masking some elements, block size 2
float scale = 1.0f;
float* alibi = nullptr;
void* attn_mask = nullptr;
uint8_t* causal_mask = nullptr;
bool select_nfltmax_at_0 = false;
size_t len = input.size();
size_t total_size = input.size();
ov::element::Type attn_mask_prec = ov::element::f32;
ov::element::Type dst_precision = ov::element::f32;
const float* sink = nullptr;
float alibi_slope = 0.0f;
size_t sparse_block_size = 2;
ov::Extensions::Cpu::XARCH::attn_softmax_kernel<float>(input.data(),
output.data(),
scale,
alibi,
attn_mask,
causal_mask,
select_nfltmax_at_0,
len,
total_size,
attn_mask_prec,
dst_precision,
sink,
alibi_slope,
sparse_mask.data(),
sparse_block_size);
std::vector<float> expect_output = {0.00483724f, 0.013149f, 0.0f, 0.0f, 0.264104f, 0.71791f, 0.0f, 0.0f};
for (size_t i = 0; i < output.size(); ++i) {
EXPECT_NEAR(output[i], expect_output[i], 1e-5f);
}
}

TEST(SoftmaxKernelTest, AttnSoftmaxKernelWithNaNInputAndSparseMask) {
std::vector<float> input = {1.0f, 2.0f, std::nanf(""), 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
std::vector<float> output(input.size(), 0.0f);
std::vector<uint8_t> sparse_mask = {1, 0, 1, 0}; // Masking some elements, block size 2
float scale = 1.0f;
float* alibi = nullptr;
void* attn_mask = nullptr;
uint8_t* causal_mask = nullptr;
bool select_nfltmax_at_0 = false;
size_t len = input.size();
size_t total_size = input.size();
ov::element::Type attn_mask_prec = ov::element::f32;
ov::element::Type dst_precision = ov::element::f32;
const float* sink = nullptr;
float alibi_slope = 0.0f;
size_t sparse_block_size = 2;
ov::Extensions::Cpu::XARCH::attn_softmax_kernel<float>(input.data(),
output.data(),
scale,
alibi,
attn_mask,
causal_mask,
select_nfltmax_at_0,
len,
total_size,
attn_mask_prec,
dst_precision,
sink,
alibi_slope,
sparse_mask.data(),
sparse_block_size);
std::vector<float> expect_output = {0.00483724f, 0.013149f, 0.0f, 0.0f, 0.264104f, 0.71791f, 0.0f, 0.0f};
for (size_t i = 0; i < output.size(); ++i) {
EXPECT_NEAR(output[i], expect_output[i], 1e-5f);
}
}

} // namespace
Loading