
Commit 4885881

modified sve_utils, added Copyright info, rebased
1 parent 5c0d665 commit 4885881

File tree

4 files changed, +80 -61 lines changed


src/plugins/intel_cpu/src/nodes/kernels/aarch64/pa_kernels.hpp

Lines changed: 0 additions & 51 deletions
This file was deleted.
src/plugins/intel_cpu/src/nodes/kernels/aarch64/sve_utils.hpp

Lines changed: 53 additions & 0 deletions

@@ -0,0 +1,53 @@
+// Copyright (C) 2024 FUJITSU LIMITED
+// SPDX-License-Identifier: Apache-2.0
+//
+#include <arm_sve.h>
+
+#include "openvino/core/type/float16.hpp"
+
+template <typename T, typename... Args>
+constexpr bool one_of(T val, Args... args) {
+    return ((val == args) || ...);
+}
+
+template <size_t T_SIZE>
+svbool_t sve_predicate() {
+    static_assert(one_of(T_SIZE, 8, 16, 32, 64), "Unexpected parameter size");
+    if constexpr (8 == T_SIZE) {
+        return svptrue_b8();
+    } else if (16 == T_SIZE) {
+        return svptrue_b16();
+    } else if (32 == T_SIZE) {
+        return svptrue_b32();
+    } else if (64 == T_SIZE) {
+        return svptrue_b64();
+    }
+}
+
+template <typename T_TYPE, size_t T_SIZE>
+svbool_t sve_predicate(T_TYPE lower, T_TYPE higher) {
+    static_assert(one_of(T_SIZE, 8, 16, 32, 64), "Unexpected parameter size");
+    if constexpr (8 == T_SIZE) {
+        return svwhilelt_b8(lower, higher);
+    } else if (16 == T_SIZE) {
+        return svwhilelt_b16(lower, higher);
+    } else if (32 == T_SIZE) {
+        return svwhilelt_b32(lower, higher);
+    } else if (64 == T_SIZE) {
+        return svwhilelt_b64(lower, higher);
+    }
+}
+
+template <size_t T_SIZE>
+size_t sve_vlen() {
+    static_assert(one_of(T_SIZE, 8, 16, 32, 64), "Unexpected parameter size");
+    if constexpr (8 == T_SIZE) {
+        return svcntb();
+    } else if (16 == T_SIZE) {
+        return svcnth();
+    } else if (32 == T_SIZE) {
+        return svcntw();
+    } else if (64 == T_SIZE) {
+        return svcntd();
+    }
+}
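
Note (editor): the following is a minimal usage sketch of the helpers above, not part of the commit. It assumes an AArch64 toolchain with SVE enabled and passes T_SIZE as the element width in bits, matching the values accepted by the static_assert above; the function name copy_f32 is hypothetical.

#include <arm_sve.h>
#include <cstddef>
// sve_utils.hpp helpers from above are assumed to be visible here.

void copy_f32(float* dst, const float* src, size_t n) {
    size_t i = 0;
    auto pg = sve_predicate<32>();                    // all-true predicate for 32-bit lanes
    auto vlen = sve_vlen<32>();                       // number of 32-bit lanes per vector
    for (; i + vlen <= n; i += vlen) {                // full-vector iterations
        svst1(pg, dst + i, svld1(pg, src + i));
    }
    auto pg_tail = sve_predicate<size_t, 32>(i, n);   // mask covering the remaining n - i elements
    svst1(pg_tail, dst + i, svld1(pg_tail, src + i));
}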

src/plugins/intel_cpu/src/nodes/kernels/kai/kleidi_kernel.hpp

Lines changed: 6 additions & 0 deletions
@@ -1,3 +1,9 @@
+// Copyright (C) 2024 FUJITSU LIMITED
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
 #include <arm_neon.h>
 #include <kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h>
 #include <kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p_interface.h>

src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp

Lines changed: 21 additions & 10 deletions
@@ -30,7 +30,7 @@
 #elif defined(OPENVINO_ARCH_ARM64) && defined(HAVE_SVE)
 #    include "arm_sve.h"
 #    include "nodes/kernels/aarch64/brgemm_kernel.hpp"
-#    include "nodes/kernels/aarch64/pa_kernels.hpp"
+#    include "nodes/kernels/aarch64/sve_utils.hpp"
 #    include "nodes/kernels/kai/kleidi_kernel.hpp"
 #endif

@@ -39,7 +39,7 @@ namespace ov::Extensions::Cpu::XARCH {
 using namespace ov;
 using namespace ov::intel_cpu;

-// currently depends on brgemm which only support x64
+// currently depends on brgemm which only support x64 or ARM SVE
 #if defined(OPENVINO_ARCH_X86_64) || (defined(OPENVINO_ARCH_ARM64) && defined(HAVE_SVE))

 #    if defined(HAVE_AVX2) || defined(HAVE_AVX512F)
@@ -72,16 +72,27 @@ void cvt_copy(TA* dst, TB* src, size_t n) {
     }
 #    elif defined(HAVE_SVE)
     if constexpr (std::is_same<TA, TB>::value) {
-        SVE_PREDICATE(pg_dst, TA)
-        SVE_VLEN(vlen, TA)
+        auto pg_dst = sve_predicate<sizeof(TA)>();
+        auto vlen = sve_vlen<sizeof(TA)>();
         for (; i + vlen <= n; i += vlen) {
             auto vb = svld1(pg_dst, src + i);
             svst1(pg_dst, dst + i, vb);
         }
-        SVE_PREDICATE_WHILELT(pgt, TA, i, n)
+        auto pgt = sve_predicate<TA, sizeof(TA)>(i, n);
         auto vb = svld1(pg_dst, src + i);
         svst1(pg_dst, dst + i, vb);
         return;
+    } else if constexpr (std::is_same<TA, float>::value && std::is_same<TB, ov::float16>::value) {
+        auto src_ptr = reinterpret_cast<float16_t*>(src);
+        auto pg_vl2 = svwhilelt_b16(svcnth() / 2, svcnth());
+        auto vlen = svcnth() / 2;
+        auto pg_dst = svptrue_b32();
+        for (; i + vlen <= n; i += vlen) {
+            auto load_src = svld1_f16(pg_vl2, src_ptr + i);
+            auto src_interleave = svzip1_f16(load_src, load_src);
+            auto cvt_dst = svcvt_f32_f16_z(pg_dst, src_interleave);
+            svst1(pg_dst, dst + i, cvt_dst);
+        }
     }
 #    endif
     for (; i < n; i++) {
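
Note (editor): the new TA == float / TB == ov::float16 branch above widens f16 to f32 entirely in registers. A self-contained sketch of the same pattern, with hypothetical names and not part of the commit: svcvt_f32_f16 reads the halfword in the low half of each 32-bit container (the even-numbered f16 lanes on little-endian), so zip1 of a half-length f16 load with itself places every source element into such a lane before the conversion.

#include <arm_sve.h>
#include <cstddef>

void widen_f16_to_f32(float* dst, const float16_t* src, size_t n) {
    size_t i = 0;
    const size_t step = svcnth() / 2;                     // f32 lanes per SVE vector
    const svbool_t pg_half = svwhilelt_b16(0ul, step);    // first half of the f16 lanes
    const svbool_t pg_f32 = svptrue_b32();
    for (; i + step <= n; i += step) {
        svfloat16_t lo = svld1_f16(pg_half, src + i);     // load 'step' f16 values
        svfloat16_t dup = svzip1_f16(lo, lo);             // copy each value into an even lane
        svfloat32_t wide = svcvt_f32_f16_z(pg_f32, dup);  // widen the even lanes to f32
        svst1_f32(pg_f32, dst + i, wide);
    }
    for (; i < n; i++) {                                  // scalar tail, as in cvt_copy
        dst[i] = static_cast<float>(src[i]);
    }
}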
@@ -1715,7 +1726,7 @@ struct MHAHelper {
     void init_reorder_buffers(size_t batch, size_t kv_len_in_blocks) {
         _qk_scratch_b.resize<DATA_TYPE>({batch, kv_len_in_blocks, Hk, _block_size * S});
         if (AarchF16) {
-            // Required to keep kv_cache continuous in mem, as kleidi do to support accumulation
+            // It is required to keep kv_cache continuous in mem, as kleidi do not support accumulation
            _wv_scratch_b.resize<DATA_TYPE>({batch, Hk, kv_len_in_blocks, _block_size * rnd_up(SV, _block_size)});
         } else {
            _wv_scratch_b.resize<DATA_TYPE>({batch, kv_len_in_blocks, Hk, _block_size * rnd_up(SV, _block_size)});
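
Note (editor): an illustrative sketch of what the reordered f16 layout buys, with hypothetical helper names; it is not part of the commit. With {batch, Hk, kv_len_in_blocks, block} the weighted-value blocks of one head sit back to back, so a KleidiAI matmul call can consume the whole sequence at once instead of accumulating partial results per block, which its ukernels do not support.

#include <cstddef>

// Offset of kv block 'blk' for head 'hk' in batch 'b', all sizes in elements.
size_t wv_offset_f16(size_t b, size_t hk, size_t blk,
                     size_t Hk, size_t kv_blocks, size_t block_elems) {
    // {batch, Hk, kv_len_in_blocks, block_elems}: consecutive blk values are adjacent
    return ((b * Hk + hk) * kv_blocks + blk) * block_elems;
}

size_t wv_offset_default(size_t b, size_t hk, size_t blk,
                         size_t Hk, size_t kv_blocks, size_t block_elems) {
    // {batch, kv_len_in_blocks, Hk, block_elems}: blocks of one head are Hk * block_elems apart
    return ((b * kv_blocks + blk) * Hk + hk) * block_elems;
}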
@@ -1918,7 +1929,7 @@
         auto _score_stride = _weight.stride_bytes(2) / 2;
         for (size_t h = hq_beg; h < hq_end; h++) {
             auto* q_ptr = query.ptr<DATA_TYPE>(h, q_start, 0);
-            float* c_ptr = _weight.ptr<float>(ithr, h, 0, 0);
+            float* c_ptr = _weight.ptr<float>(ithr, h - hq_beg, 0, 0);
             // for each query block, loop through all key block
             // for blocks:
             // 1 0 0 0 ...
@@ -1947,8 +1958,8 @@
             for (size_t m = q_start; m < q_end; m++) {
                 // apply softmax in f32 precision
                 auto ncausal = (cur_kv_len - q_cnt + (m - q_start) + 1);
-                auto soft_in = _weight.ptr<float>(ithr, h, m - q_start);
-                auto score = _weight.ptr<float>(ithr, h, m - q_start);
+                auto soft_in = _weight.ptr<float>(ithr, h - hq_beg, m - q_start);
+                auto score = _weight.ptr<float>(ithr, h - hq_beg, m - q_start);
                 PlainTensor f32_cvt;
                 if (q_is_xf16) {
                     f32_cvt.resize<float>({size_t{rnd_up(cur_kv_len, _block_size)}});
@@ -2007,7 +2018,7 @@
             }

             // reuse float buffer, need to use float to compute offset
-            auto* w_ptr = reinterpret_cast<DATA_TYPE*>(_weight.ptr<float>(ithr, h, 0, 0));
+            auto* w_ptr = reinterpret_cast<DATA_TYPE*>(_weight.ptr<float>(ithr, h - hq_beg, 0, 0));
             DATA_TYPE* out_ptr = output_emb.ptr<DATA_TYPE>(q_start, h * SV);
             DATA_TYPE* v_ptr;
             v_ptr = wv_scratch_b.ptr<DATA_TYPE>(hk, 0);
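
Note (editor): the h → h - hq_beg changes above rebase the head index before addressing the per-thread _weight buffer. If that buffer only holds rows for the head group [hq_beg, hq_end) handled by the current call, indexing it with the absolute head id would run past its head axis whenever hq_beg > 0. A hypothetical sketch of the rebased lookup, not part of the commit:

#include <cstddef>
#include <vector>

// Stand-in for the per-thread score buffer: one row per head of the current
// group [hq_beg, hq_end), each row holding 'row_elems' floats.
float* weight_row(std::vector<float>& buf, size_t h, size_t hq_beg, size_t row_elems) {
    size_t local_h = h - hq_beg;              // rebase the absolute head id into the group
    return buf.data() + local_h * row_elems;  // stays inside the group-sized buffer
}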
