Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,79 +33,6 @@

namespace ov::Extensions::Cpu::XARCH {

#if defined(OPENVINO_ARCH_ARM64)
namespace detail {
// Zero-fill the destination buffer, honoring the element width implied by
// dst_precision (f32 -> 4 bytes per element, otherwise f16 -> 2 bytes).
inline void zero_out_dst(void* a_dst, ov::element::Type dst_precision, size_t total_size) {
    if (total_size == 0) {
        return;
    }
    const size_t elem_bytes = (dst_precision == ov::element::f32) ? sizeof(float) : sizeof(ov::float16);
    memset(a_dst, 0, elem_bytes * total_size);
}

// Degenerate case: zero logits to normalize. When len == 0 the whole
// destination buffer is zeroed and true is returned so the caller can
// return early; otherwise nothing is touched and false is returned.
inline bool handle_empty_len(size_t len, void* a_dst, ov::element::Type dst_precision, size_t total_size) {
    if (len == 0) {
        zero_out_dst(a_dst, dst_precision, total_size);
        return true;
    }
    return false;
}

// Widen any supported logit element type (e.g. ov::float16) to float for
// accumulation and comparisons. No explicit specialization for float is
// needed: static_cast<float>(float) is already the identity conversion.
template <typename T>
inline float to_float(T value) {
    return static_cast<float>(value);
}

// When the running maximum of the logits is +inf, softmax degenerates:
// every +inf logit gets an equal share of the probability mass and every
// finite logit gets zero. This writes that degenerate distribution into
// a_dst (laid out as f32 or f16 per dst_precision) and zero-pads the tail
// region [len, total_size). `sink` is an optional extra logit that is
// counted toward the share denominator but has no slot in the output.
// Always returns true so the caller can early-out.
template <typename T>
inline bool handle_inf_logits(const T* a,
                              void* a_dst,
                              ov::element::Type dst_precision,
                              size_t len,
                              size_t total_size,
                              const float* sink) {
    size_t pos_inf = 0;
    if (sink != nullptr && std::isinf(*sink) && *sink > 0.0F) {
        ++pos_inf;
    }
    for (size_t i = 0; i < len; ++i) {
        const float v = to_float(a[i]);
        if (std::isinf(v) && v > 0.0F) {
            ++pos_inf;
        }
    }
    const float share = (pos_inf != 0) ? 1.0F / static_cast<float>(pos_inf) : 0.0F;
    // Weight assigned to element i: equal share if it is a +inf logit
    // (and any +inf was seen at all), zero otherwise.
    const auto weight_of = [&](size_t i) -> float {
        const float v = to_float(a[i]);
        return (pos_inf != 0 && std::isinf(v) && v > 0.0F) ? share : 0.0F;
    };
    if (dst_precision == ov::element::f32) {
        auto* out = static_cast<float*>(a_dst);
        for (size_t i = 0; i < len; ++i) {
            out[i] = weight_of(i);
        }
        if (total_size > len) {
            memset(out + len, 0, sizeof(float) * (total_size - len));
        }
    } else {
        auto* out = static_cast<ov::float16*>(a_dst);
        for (size_t i = 0; i < len; ++i) {
            out[i] = ov::float16(weight_of(i));
        }
        if (total_size > len) {
            memset(out + len, 0, sizeof(ov::float16) * (total_size - len));
        }
    }
    return true;
}
} // namespace detail
#endif

#if defined(HAVE_AVX2)
inline void exp_ps_avx2(__m256& src) {
# define REPEAT8(x) x, x, x, x, x, x, x, x
Expand All @@ -129,7 +56,7 @@ inline void exp_ps_avx2(__m256& src) {
__m256 exp_log2ef = _mm256_loadu_ps(reinterpret_cast<const float*>(c_e)); // log2(e)
__m256 half = _mm256_loadu_ps(reinterpret_cast<const float*>(c_half)); // 0.5f
__m256 ln2f = _mm256_loadu_ps(reinterpret_cast<const float*>(c_ln2)); // ln(2)
__m256 one = _mm256_loadu_ps(reinterpret_cast<const float*>(c_1)); // 1.0F
__m256 one = _mm256_loadu_ps(reinterpret_cast<const float*>(c_1)); // 1.0f
__m256i exponent_bias = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(c_bias)); // 127
__m256 exp_pol1 = _mm256_loadu_ps(reinterpret_cast<const float*>(c_p1)); // p1 = 0.999999701f
__m256 exp_pol2 = _mm256_loadu_ps(reinterpret_cast<const float*>(c_p2)); // p2 = 0.499991506f
Expand Down Expand Up @@ -566,9 +493,8 @@ inline void scale_add2_reduce_max(float* a,
if (has_sparse_mask) {
size_t mask_idx = i / sparse_block_size;
uint8_t mask_val = sparse_mask[mask_idx];
if (!mask_val) {
v_a = v_nfltmax;
}
float32x4_t v_mask_block = vdupq_n_f32(mask_val ? 0.f : -FLT_MAX);
v_a = vaddq_f32(v_a, v_mask_block);
}

if (has_causal_mask) {
Expand All @@ -577,9 +503,9 @@ inline void scale_add2_reduce_max(float* a,
uint32x4_t v_maski32_low = vmovl_u16(vget_low_u16(v_maski16));
uint32x4_t v_maski32_high = vmovl_u16(vget_high_u16(v_maski16));
uint32x4_t v_maski32[2] = {v_maski32_low, v_maski32_high};
for (const auto& mask_vec : v_maski32) {
uint32x4_t kmask = vceqq_u32(mask_vec, v_zeroi32); // ==0
v_a = vbslq_f32(kmask, v_nfltmax, v_a); // mask => -FLT_MAX
for (int j = 0; j < 2; ++j) {
uint32x4_t kmask = vceqq_u32(v_maski32[j], v_zeroi32); // ==0
v_a = vbslq_f32(kmask, v_nfltmax, v_a); // mask => -FLT_MAX
}
}

Expand All @@ -603,9 +529,7 @@ inline void scale_add2_reduce_max(float* a,
if (has_sparse_mask) {
size_t mask_idx = i / sparse_block_size;
uint8_t mask_val = sparse_mask[mask_idx];
if (!mask_val) {
a[i] = -FLT_MAX;
}
a[i] += (mask_val ? 0.f : -FLT_MAX);
}

if (has_causal_mask) {
Expand Down Expand Up @@ -667,7 +591,7 @@ inline void scale_add2_reduce_max(ov::float16* a,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, ov::float16>,
"attn_mask must be float or float16 type.");
if constexpr (std::is_same_v<T, float>) {
svfloat16_t zero = svdup_n_f16(0.0F);
svfloat16_t zero = svdup_n_f16(0.0f);
size_t inc_low = (vec_len + 1) / 2;
size_t inc_high = vec_len / 2;
svbool_t pg_f32_low = svwhilelt_b32(0, static_cast<int>(inc_low));
Expand Down Expand Up @@ -798,7 +722,7 @@ static inline void exp_ps_avx512(__m512& src) {
__m512 exp_log2ef = _mm512_loadu_ps(reinterpret_cast<const float*>(c_e)); // log2(e)
__m512 half = _mm512_loadu_ps(reinterpret_cast<const float*>(c_half)); // 0.5f
__m512 ln2f = _mm512_loadu_ps(reinterpret_cast<const float*>(c_ln2)); // ln(2)
__m512 one = _mm512_loadu_ps(reinterpret_cast<const float*>(c_1)); // 1.0F
__m512 one = _mm512_loadu_ps(reinterpret_cast<const float*>(c_1)); // 1.0f
__m512i exponent_bias = _mm512_loadu_si512(c_bias); // 127
__m512 exp_pol1 = _mm512_loadu_ps(reinterpret_cast<const float*>(c_p1)); // p1 = 0.999999701f
__m512 exp_pol2 = _mm512_loadu_ps(reinterpret_cast<const float*>(c_p2)); // p2 = 0.499991506f
Expand Down Expand Up @@ -864,7 +788,7 @@ inline void exp_reduce_sum(float* a, const float max, const size_t size, float&
#if defined(HAVE_AVX512F)
__m512 v_a;
auto v_max = _mm512_set1_ps(max);
auto v_sum = _mm512_set1_ps(0.0F);
auto v_sum = _mm512_set1_ps(0.0f);
while (i + vec_len_f32_avx512 <= size) {
v_a = _mm512_loadu_ps(a + i);
v_a = _mm512_sub_ps(v_a, v_max);
Expand All @@ -888,7 +812,7 @@ inline void exp_reduce_sum(float* a, const float max, const size_t size, float&
#elif defined(HAVE_AVX2)
__m256 v_a;
auto v_max = _mm256_set1_ps(max);
auto v_sum = _mm256_set1_ps(0.0F);
auto v_sum = _mm256_set1_ps(0.0f);
while (i + vec_len_f32_avx2 <= size) {
v_a = _mm256_loadu_ps(a + i);
v_a = _mm256_sub_ps(v_a, v_max);
Expand All @@ -915,7 +839,7 @@ inline void exp_reduce_sum(float* a, const float max, const size_t size, float&
# if defined(HAVE_SVE)
svfloat32_t v_a;
svfloat32_t v_max = svdup_n_f32(max);
svfloat32_t v_sum = svdup_n_f32(0.0F);
svfloat32_t v_sum = svdup_n_f32(0.0f);
size_t inc = vec_len_f32_sve();
svbool_t pg = svptrue_b32();

Expand All @@ -935,7 +859,7 @@ inline void exp_reduce_sum(float* a, const float max, const size_t size, float&
# else
float32x4_t v_a;
float32x4_t v_max = vdupq_n_f32(max);
float32x4_t v_sum = vdupq_n_f32(0.0F);
float32x4_t v_sum = vdupq_n_f32(0.0f);

while (i + vec_len_f32_neon <= size) {
v_a = vld1q_f32(a + i);
Expand All @@ -959,7 +883,7 @@ inline void exp_reduce_sum_f32(ov::float16* a, const ov::float16 max, const size
size_t i = 0;
# if defined(HAVE_SVE)
svfloat32_t v_max = svdup_n_f32(static_cast<float>(max));
svfloat32_t v_sum = svdup_n_f32(0.0F);
svfloat32_t v_sum = svdup_n_f32(0.0f);

svbool_t pg_f32 = svptrue_b32();
svbool_t pg_f16 = svptrue_b16();
Expand Down Expand Up @@ -988,7 +912,7 @@ inline void exp_reduce_sum_f32(ov::float16* a, const ov::float16 max, const size
# else
float32x4_t v_a;
float32x4_t v_max = vdupq_n_f32(static_cast<float>(max));
float32x4_t v_sum = vdupq_n_f32(0.0F);
float32x4_t v_sum = vdupq_n_f32(0.0f);

// Process 4 FP32 elements at a time
for (; i + vec_len_f32_neon <= size; i += vec_len_f32_neon) {
Expand All @@ -1011,7 +935,7 @@ inline void exp_reduce_sum_f32(ov::float16* a, const ov::float16 max, const size
# endif
// Handle remaining elements
for (; i < size; ++i) {
const float val = std::exp(static_cast<float>(a[i] - max));
float val = exp(static_cast<float>(a[i] - max));
a[i] = static_cast<ov::float16>(val);
total_sum += val;
}
Expand All @@ -1026,7 +950,7 @@ inline void exp_reduce_sum(ov::float16* a, const ov::float16 max, const size_t s
# if defined(HAVE_SVE)
svfloat16_t v_a;
svfloat16_t v_max = svdup_n_f16(max);
svfloat16_t v_sum = svdup_n_f16(0.0F);
svfloat16_t v_sum = svdup_n_f16(0.0f);
svbool_t pg = svptrue_b16();
size_t inc = vec_len_f16_sve();

Expand All @@ -1047,7 +971,7 @@ inline void exp_reduce_sum(ov::float16* a, const ov::float16 max, const size_t s
const size_t vec_len_f16_neon = 8;
float16x8_t v_a;
float16x8_t v_max = vdupq_n_f16(max);
float16x8_t v_sum = vdupq_n_f16(0.0F);
float16x8_t v_sum = vdupq_n_f16(0.0f);

for (; i + vec_len_f16_neon <= size; i += vec_len_f16_neon) {
v_a = vld1q_f16(reinterpret_cast<const float16_t*>(a + i));
Expand Down Expand Up @@ -1138,7 +1062,7 @@ inline void multiply_scalar(float* a, float* a_dst, const float val, const size_
}

template <typename T, typename = std::enable_if_t<ov::intel_cpu::any_of_v<T, ov::bfloat16, ov::float16>>>
inline void multiply_scalar(const float* a, T* a_dst, const float val, const size_t size) {
inline void multiply_scalar(float* a, T* a_dst, const float val, const size_t size) {
size_t i = 0;
#if defined(HAVE_AVX512F)
auto v_scale = _mm512_set1_ps(val);
Expand Down Expand Up @@ -1181,7 +1105,7 @@ inline void multiply_scalar(ov::float16* a, float* a_dst, const ov::float16 val,
}

for (; i < size; ++i) {
const auto a_f32 = static_cast<float>(a[i]);
float a_f32 = static_cast<float>(a[i]);
a_dst[i] = a_f32 * static_cast<float>(val);
}
}
Expand Down Expand Up @@ -1290,11 +1214,6 @@ inline void attn_softmax_kernel<float>(float* a,
float&,
const uint8_t*,
size_t);
#if defined(OPENVINO_ARCH_ARM64)
if (detail::handle_empty_len(len, a_dst, dst_precision, total_size)) {
return;
}
#endif
using func_f16_type = void (*)(float*,
float,
const float*,
Expand Down Expand Up @@ -1395,23 +1314,17 @@ inline void attn_softmax_kernel<float>(float* a,
sparse_block_size);
}

float sum = 0.0F;
float sum = 0.0f;
if (sink != nullptr) {
max = max > (*sink) ? max : (*sink);
}
#if defined(OPENVINO_ARCH_ARM64)
if (std::isinf(max) && max > 0.0F) {
detail::handle_inf_logits(a, a_dst, dst_precision, len, total_size, sink);
return;
}
#endif
// exp sum
exp_reduce_sum(a, max, len, sum);
if (sink != nullptr) {
sum += std::exp(*sink - max);
}
// divide sum
float scalar = 1.0F / sum;
float scalar = 1.0f / sum;
if (dst_precision == ov::element::f32) {
multiply_scalar(a, reinterpret_cast<float*>(a_dst), scalar, len);
// apply causual mask to final result instead of attn_score
Expand Down Expand Up @@ -1505,11 +1418,6 @@ inline void attn_softmax_kernel<ov::float16>(ov::float16* a,
scale_add2_reduce_max<true, true, false>,
scale_add2_reduce_max<true, true, true>};
int dispatch = (alibi ? 0b100 : 0) | (attn_mask ? 0b010 : 0) | (causal_mask ? 0b001 : 0);
# if defined(OPENVINO_ARCH_ARM64)
if (detail::handle_empty_len(len, a_dst, dst_precision, total_size)) {
return;
}
# endif
ov::float16 max = std::numeric_limits<ov::float16>::lowest();
if (attn_mask_prec == ov::element::f16) {
funcs_fp16[dispatch](a,
Expand Down Expand Up @@ -1545,28 +1453,17 @@ inline void attn_softmax_kernel<ov::float16>(ov::float16* a,
max);
}

ov::float16 sum = 0.0F;
if (sink != nullptr) {
max = std::max(max, static_cast<const ov::float16>(*sink));
}
# if defined(OPENVINO_ARCH_ARM64)
const float max_f = static_cast<float>(max);
if (std::isinf(max_f) && max_f > 0.0F) {
detail::handle_inf_logits(a, a_dst, dst_precision, len, total_size, sink);
return;
}
# endif
exp_reduce_sum_f32(a, max, len, sum);
if (sink != nullptr) {
sum += std::exp(*sink - max);
}
ov::float16 scalar = 1.0F / sum;
ov::float16 sum = 0.0f;
if (dst_precision == ov::element::f32) {
exp_reduce_sum_f32(a, max, len, sum);
ov::float16 scalar = 1.0f / sum;
multiply_scalar(a, static_cast<float*>(a_dst), scalar, len);
// apply causual mask to final result instead of attn_score
if (total_size > len)
memset(static_cast<float*>(a_dst) + len, 0, sizeof(float) * (total_size - len));
} else {
exp_reduce_sum_f32(a, max, len, sum);
ov::float16 scalar = 1.0f / sum;
multiply_scalar_f32(a, static_cast<ov::float16*>(a_dst), scalar, len);
// apply causual mask to final result instead of attn_score
if (total_size > len)
Expand All @@ -1575,4 +1472,4 @@ inline void attn_softmax_kernel<ov::float16>(ov::float16* a,
}
#endif

} // namespace ov::Extensions::Cpu::XARCH
} // namespace ov::Extensions::Cpu::XARCH
Loading
Loading