Skip to content

Commit 4f28ab0

Browse files
committed
Aarch64 pa ACL f16 enablement
1 parent 138699e commit 4f28ab0

File tree

5 files changed

+500
-5
lines changed

5 files changed

+500
-5
lines changed

src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,16 @@ ov::pass::ConvertPagedAttnInputs::ConvertPagedAttnInputs(const KVCacheConfig& co
7272
const auto pa_op = m.get_match_root();
7373
auto key_cache = ov::as_type_ptr<ov::op::v0::Parameter>(pa_op->get_input_node_shared_ptr(3));
7474
auto value_cache = ov::as_type_ptr<ov::op::v0::Parameter>(pa_op->get_input_node_shared_ptr(4));
75+
#if defined(OPENVINO_ARCH_ARM64)
76+
auto format_cache_precision = [](ov::element::Type cache_precision, ov::element::Type infer_precision) {
77+
return ov::element::u8;
78+
};
79+
#else
7580
auto format_cache_precision = [](ov::element::Type cache_precision, ov::element::Type infer_precision) {
7681
return cache_precision == ov::element::f16 && infer_precision == ov::element::bf16 ? infer_precision
7782
: cache_precision;
7883
};
84+
#endif
7985
auto init_cache_shape = [&](const size_t head_nums,
8086
const size_t head_size,
8187
const size_t block_size,
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// Copyright (C) 2024 FUJITSU LIMITED
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
#include <arm_sve.h>
5+
6+
#include "openvino/core/type/float16.hpp"
7+
8+
// Width in bits of a type or expression.
// Fully parenthesized so the expansion is safe inside larger expressions,
// e.g. `64 / SIZE_IN_BITS(T)`.
#define SIZE_IN_BITS(t_var) (sizeof(t_var) * 8)

// Emits `if constexpr ((expr) == (bits)) { <statements> }`.
// The trailing variadic arguments are the statements to compile when the
// comparison holds. Both operands are parenthesized so that arbitrary
// expressions can be passed.
#define __ce(expr, bits, ...)         \
    if constexpr ((expr) == (bits)) { \
        __VA_ARGS__                   \
    }

// Declares `svbool_t var` and assigns it the svptrue_b* predicate matching the
// bit width of `t_var` (8/16/32/64). For any other width `var` is left
// unassigned.
#define SVE_PREDICATE(var, t_var)                       \
    svbool_t var;                                       \
    __ce(SIZE_IN_BITS(t_var), 8, var = svptrue_b8();)   \
    __ce(SIZE_IN_BITS(t_var), 16, var = svptrue_b16();) \
    __ce(SIZE_IN_BITS(t_var), 32, var = svptrue_b32();) \
    __ce(SIZE_IN_BITS(t_var), 64, var = svptrue_b64();)

// Declares `size_t var` and assigns it the svcnt* element count matching the
// bit width of `t_var` (8/16/32/64). For any other width `var` is left
// unassigned.
#define SVE_VLEN(var, t_var)                        \
    size_t var;                                     \
    __ce(SIZE_IN_BITS(t_var), 8, var = svcntb();)   \
    __ce(SIZE_IN_BITS(t_var), 16, var = svcnth();)  \
    __ce(SIZE_IN_BITS(t_var), 32, var = svcntw();)  \
    __ce(SIZE_IN_BITS(t_var), 64, var = svcntd();)

// Declares `svbool_t var` and assigns it svwhilelt_b*(arg1, arg2) for the
// bit width of `t_var` (8/16/32/64). For any other width `var` is left
// unassigned.
#define SVE_PREDICATE_WHILELT(var, t_var, arg1, arg2)               \
    svbool_t var;                                                   \
    __ce(SIZE_IN_BITS(t_var), 8, var = svwhilelt_b8(arg1, arg2);)   \
    __ce(SIZE_IN_BITS(t_var), 16, var = svwhilelt_b16(arg1, arg2);) \
    __ce(SIZE_IN_BITS(t_var), 32, var = svwhilelt_b32(arg1, arg2);) \
    __ce(SIZE_IN_BITS(t_var), 64, var = svwhilelt_b64(arg1, arg2);)
33+
34+
namespace ov::Extensions::Cpu::XARCH {
35+
static void cvt_copy(float* dst, ov::float16* src, size_t n) {
36+
auto src_ptr = reinterpret_cast<float16_t*>(src);
37+
auto pg_vl2 = svwhilelt_b16(svcnth() / 2, svcnth());
38+
auto vlen = svcnth() / 2;
39+
auto pg_dst = svptrue_b32();
40+
size_t i = 0;
41+
for (; i + vlen <= n; i += vlen) {
42+
auto load_src = svld1_f16(pg_vl2, src_ptr + i);
43+
auto src_interleave = svzip1_f16(load_src, load_src);
44+
auto cvt_dst = svcvt_f32_f16_z(pg_dst, src_interleave);
45+
svst1(pg_dst, dst + i, cvt_dst);
46+
}
47+
for (; i < n; i++) {
48+
dst[i] = src[i];
49+
}
50+
}
51+
52+
// Copies `n` elements from `src` to `dst`, converting f32 to f16.
// Full SVE vectors are converted in bulk; any remainder is converted element by element.
static void cvt_copy(ov::float16* dst, float* src, size_t n) {
    auto dst_ptr = reinterpret_cast<float16_t*>(dst);
    auto pg_src = svptrue_b32();
    auto pg_dst = svwhilelt_b16(svcnth() / 2, svcnth());
    // Elements handled per vector iteration: one per 32-bit lane.
    auto vlen = svcntw();
    size_t i = 0;
    // `<=` (not `<`) so that a final chunk of exactly `vlen` elements is still handled by the
    // vector path, matching the f16 -> f32 overload. Each iteration reads `vlen` floats and the
    // store is predicated by `pg_dst`, so the loop never touches elements at or beyond `n`.
    for (; i + vlen <= n; i += vlen) {
        auto load_src = svld1_f32(pg_src, src + i);
        auto cvt_dst = svcvt_f16_f32_z(pg_src, load_src);
        // Compact the converted values into the leading lanes before the predicated store.
        auto str_dst = svuzp1(cvt_dst, cvt_dst);
        svst1_f16(pg_dst, dst_ptr + i, str_dst);
    }
    // Scalar tail for the elements that do not fill a whole vector.
    for (; i < n; i++) {
        dst[i] = src[i];
    }
}
68+
} // namespace ov::Extensions::Cpu::XARCH

0 commit comments

Comments
 (0)