Skip to content

Commit 65b105a

Browse files
authored
Fix sparse mask handling in softmax kernel (#33814)
### Details: - *Fix sparse mask handling in the softmax kernel. In the sparse attention path, the sparse mask caused some blocks to be skipped, so those blocks were never written by the GEMM kernel; as a result, the corresponding regions in the output buffer remained uninitialized and their contents could decode to NaN/Inf values.* - *In this PR, we overwrite the skipped regions with -FLT_MAX to prevent NaN propagation and avoid incorrect computations in downstream kernels.* ### Tickets: - *[CVS-179625](https://jira.devtools.intel.com/browse/CVS-179625)*
1 parent 51a9edd commit 65b105a

File tree

3 files changed

+113
-16
lines changed

3 files changed

+113
-16
lines changed

src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -318,8 +318,9 @@ inline void scale_add2_reduce_max(float* a,
318318
if (has_sparse_mask) { \
319319
size_t mask_idx = (i + n * vec_len_f32_avx512) / sparse_block_size; \
320320
uint8_t mask_val = sparse_mask[mask_idx]; \
321-
__m512 v_mask_block = _mm512_set1_ps(mask_val ? 0.f : -FLT_MAX); \
322-
v_a = _mm512_add_ps(v_a, v_mask_block); \
321+
if (!mask_val) { \
322+
v_a = v_nfltmax; \
323+
} \
323324
} \
324325
if (has_causal_mask) { \
325326
auto v_maski8 = \
@@ -355,8 +356,9 @@ inline void scale_add2_reduce_max(float* a,
355356
if (has_sparse_mask) {
356357
size_t mask_idx = i / sparse_block_size;
357358
uint8_t mask_val = sparse_mask[mask_idx];
358-
__m512 v_mask_block = _mm512_set1_ps(mask_val ? 0.f : -FLT_MAX);
359-
v_a = _mm512_add_ps(v_a, v_mask_block);
359+
if (!mask_val) {
360+
v_a = v_nfltmax;
361+
}
360362
}
361363

362364
if (has_causal_mask) {
@@ -390,8 +392,9 @@ inline void scale_add2_reduce_max(float* a,
390392
if (has_sparse_mask) {
391393
size_t mask_idx = i / sparse_block_size;
392394
uint8_t mask_val = sparse_mask[mask_idx];
393-
__m512 v_mask_block = _mm512_set1_ps(mask_val ? 0.f : -FLT_MAX);
394-
v_a = _mm512_add_ps(v_a, v_mask_block);
395+
if (!mask_val) {
396+
v_a = v_nfltmax;
397+
}
395398
}
396399

397400
if (has_causal_mask) {
@@ -439,8 +442,9 @@ inline void scale_add2_reduce_max(float* a,
439442
if (has_sparse_mask) { \
440443
size_t mask_idx = (i + n * vec_len_f32_avx2) / sparse_block_size; \
441444
uint8_t mask_val = sparse_mask[mask_idx]; \
442-
__m256 v_mask_block = _mm256_set1_ps(mask_val ? 0.f : -FLT_MAX); \
443-
v_a = _mm256_add_ps(v_a, v_mask_block); \
445+
if (!mask_val) { \
446+
v_a = v_nfltmax; \
447+
} \
444448
} \
445449
if (has_causal_mask) { \
446450
auto v_maski8 = _mm_loadu_si128(reinterpret_cast<__m128i const*>(causal_mask + i + n * vec_len_f32_avx2)); \
@@ -476,8 +480,9 @@ inline void scale_add2_reduce_max(float* a,
476480
if (has_sparse_mask) {
477481
size_t mask_idx = i / sparse_block_size;
478482
uint8_t mask_val = sparse_mask[mask_idx];
479-
__m256 v_mask_block = _mm256_set1_ps(mask_val ? 0.f : -FLT_MAX);
480-
v_a = _mm256_add_ps(v_a, v_mask_block);
483+
if (!mask_val) {
484+
v_a = v_nfltmax;
485+
}
481486
}
482487

483488
if (has_causal_mask) {
@@ -512,8 +517,9 @@ inline void scale_add2_reduce_max(float* a,
512517
if (has_sparse_mask) {
513518
size_t mask_idx = i / sparse_block_size;
514519
uint8_t mask_val = sparse_mask[mask_idx];
515-
__m256 v_mask_block = _mm256_set1_ps(mask_val ? 0.f : -FLT_MAX);
516-
v_a = _mm256_add_ps(v_a, v_mask_block);
520+
if (!mask_val) {
521+
v_a = v_nfltmax;
522+
}
517523
}
518524

519525
if (has_causal_mask) {
@@ -560,8 +566,9 @@ inline void scale_add2_reduce_max(float* a,
560566
if (has_sparse_mask) {
561567
size_t mask_idx = i / sparse_block_size;
562568
uint8_t mask_val = sparse_mask[mask_idx];
563-
float32x4_t v_mask_block = vdupq_n_f32(mask_val ? 0.0F : -FLT_MAX);
564-
v_a = vaddq_f32(v_a, v_mask_block);
569+
if (!mask_val) {
570+
v_a = v_nfltmax;
571+
}
565572
}
566573

567574
if (has_causal_mask) {
@@ -596,7 +603,9 @@ inline void scale_add2_reduce_max(float* a,
596603
if (has_sparse_mask) {
597604
size_t mask_idx = i / sparse_block_size;
598605
uint8_t mask_val = sparse_mask[mask_idx];
599-
a[i] += (mask_val ? 0.0F : -FLT_MAX);
606+
if (!mask_val) {
607+
a[i] = -FLT_MAX;
608+
}
600609
}
601610

602611
if (has_causal_mask) {

src/plugins/intel_cpu/tests/unit/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ if(NOT X86_64)
3232
${CMAKE_CURRENT_SOURCE_DIR}/snippets_transformations/x64
3333
${CMAKE_CURRENT_SOURCE_DIR}/nodes/eltwise_node_test.cpp
3434
${CMAKE_CURRENT_SOURCE_DIR}/brgemm_executor_test.cpp
35-
${CMAKE_CURRENT_SOURCE_DIR}/xattention_test.cpp)
35+
${CMAKE_CURRENT_SOURCE_DIR}/xattention_test.cpp
36+
${CMAKE_CURRENT_SOURCE_DIR}/softmax_kernel_test.cpp)
3637
endif()
3738

3839
if (NOT ENABLE_MLAS_FOR_CPU)
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// Copyright (C) 2018-2026 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "nodes/kernels/scaled_attn/softmax_kernel.hpp"
6+
7+
#include <cmath>
8+
#include <vector>
9+
10+
#include "gtest/gtest.h"
11+
12+
namespace {
13+
TEST(SoftmaxKernelTest, AttnSoftmaxKernelWithSparseMask) {
14+
std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
15+
std::vector<float> output(input.size(), 0.0f);
16+
std::vector<uint8_t> sparse_mask = {1, 0, 1, 0}; // Masking some elements, block size 2
17+
float scale = 1.0f;
18+
float* alibi = nullptr;
19+
void* attn_mask = nullptr;
20+
uint8_t* causal_mask = nullptr;
21+
bool select_nfltmax_at_0 = false;
22+
size_t len = input.size();
23+
size_t total_size = input.size();
24+
ov::element::Type attn_mask_prec = ov::element::f32;
25+
ov::element::Type dst_precision = ov::element::f32;
26+
const float* sink = nullptr;
27+
float alibi_slope = 0.0f;
28+
size_t sparse_block_size = 2;
29+
ov::Extensions::Cpu::XARCH::attn_softmax_kernel<float>(input.data(),
30+
output.data(),
31+
scale,
32+
alibi,
33+
attn_mask,
34+
causal_mask,
35+
select_nfltmax_at_0,
36+
len,
37+
total_size,
38+
attn_mask_prec,
39+
dst_precision,
40+
sink,
41+
alibi_slope,
42+
sparse_mask.data(),
43+
sparse_block_size);
44+
std::vector<float> expect_output = {0.00483724f, 0.013149f, 0.0f, 0.0f, 0.264104f, 0.71791f, 0.0f, 0.0f};
45+
for (size_t i = 0; i < output.size(); ++i) {
46+
EXPECT_NEAR(output[i], expect_output[i], 1e-5f);
47+
}
48+
}
49+
50+
TEST(SoftmaxKernelTest, AttnSoftmaxKernelWithNaNInputAndSparseMask) {
51+
std::vector<float> input = {1.0f, 2.0f, std::nanf(""), 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
52+
std::vector<float> output(input.size(), 0.0f);
53+
std::vector<uint8_t> sparse_mask = {1, 0, 1, 0}; // Masking some elements, block size 2
54+
float scale = 1.0f;
55+
float* alibi = nullptr;
56+
void* attn_mask = nullptr;
57+
uint8_t* causal_mask = nullptr;
58+
bool select_nfltmax_at_0 = false;
59+
size_t len = input.size();
60+
size_t total_size = input.size();
61+
ov::element::Type attn_mask_prec = ov::element::f32;
62+
ov::element::Type dst_precision = ov::element::f32;
63+
const float* sink = nullptr;
64+
float alibi_slope = 0.0f;
65+
size_t sparse_block_size = 2;
66+
ov::Extensions::Cpu::XARCH::attn_softmax_kernel<float>(input.data(),
67+
output.data(),
68+
scale,
69+
alibi,
70+
attn_mask,
71+
causal_mask,
72+
select_nfltmax_at_0,
73+
len,
74+
total_size,
75+
attn_mask_prec,
76+
dst_precision,
77+
sink,
78+
alibi_slope,
79+
sparse_mask.data(),
80+
sparse_block_size);
81+
std::vector<float> expect_output = {0.00483724f, 0.013149f, 0.0f, 0.0f, 0.264104f, 0.71791f, 0.0f, 0.0f};
82+
for (size_t i = 0; i < output.size(); ++i) {
83+
EXPECT_NEAR(output[i], expect_output[i], 1e-5f);
84+
}
85+
}
86+
87+
} // namespace

0 commit comments

Comments
 (0)