1+ // Copyright (C) 2024 FUJITSU LIMITED
2+ // SPDX-License-Identifier: Apache-2.0
3+ //
4+ #include < arm_sve.h>
5+
6+ #include " openvino/core/type/float16.hpp"
7+
8+ #define SIZE_IN_BITS (t_var ) sizeof (t_var) * 8
9+ #define __ce (expr, bits, ...) \
10+ if constexpr (expr == bits) { \
11+ __VA_ARGS__ \
12+ }
13+
14+ #define SVE_PREDICATE (var, t_var ) \
15+ svbool_t var; \
16+ \
17+ __ce (SIZE_IN_BITS(t_var), 8, var = svptrue_b8();) __ce(SIZE_IN_BITS(t_var), 16 , var = svptrue_b16();) \
18+ __ce (SIZE_IN_BITS(t_var), 32, var = svptrue_b32();) __ce(SIZE_IN_BITS(t_var), 64 , var = svptrue_b64();)
19+
20+ #define SVE_VLEN (var, t_var ) \
21+ size_t var; \
22+ \
23+ __ce (SIZE_IN_BITS(t_var), 8, var = svcntb();) __ce(SIZE_IN_BITS(t_var), 16 , var = svcnth();) \
24+ __ce (SIZE_IN_BITS(t_var), 32, var = svcntw();) __ce(SIZE_IN_BITS(t_var), 64 , var = svcntd();)
25+
26+ #define SVE_PREDICATE_WHILELT (var, t_var, arg1, arg2 ) \
27+ svbool_t var; \
28+ \
29+ __ce (SIZE_IN_BITS(t_var), 8, var = svwhilelt_b8(arg1, arg2);) \
30+ __ce (SIZE_IN_BITS(t_var), 16, var = svwhilelt_b16(arg1, arg2);) \
31+ __ce (SIZE_IN_BITS(t_var), 32, var = svwhilelt_b32(arg1, arg2);) \
32+ __ce (SIZE_IN_BITS(t_var), 64, var = svwhilelt_b64(arg1, arg2);)
33+
34+ namespace ov ::Extensions::Cpu::XARCH {
35+ static void cvt_copy (float * dst, ov::float16* src, size_t n) {
36+ auto src_ptr = reinterpret_cast <float16_t *>(src);
37+ auto pg_vl2 = svwhilelt_b16 (svcnth () / 2 , svcnth ());
38+ auto vlen = svcnth () / 2 ;
39+ auto pg_dst = svptrue_b32 ();
40+ size_t i = 0 ;
41+ for (; i + vlen <= n; i += vlen) {
42+ auto load_src = svld1_f16 (pg_vl2, src_ptr + i);
43+ auto src_interleave = svzip1_f16 (load_src, load_src);
44+ auto cvt_dst = svcvt_f32_f16_z (pg_dst, src_interleave);
45+ svst1 (pg_dst, dst + i, cvt_dst);
46+ }
47+ for (; i < n; i++) {
48+ dst[i] = src[i];
49+ }
50+ }
51+
52+ static void cvt_copy (ov::float16* dst, float * src, size_t n) {
53+ auto dst_ptr = reinterpret_cast <float16_t *>(dst);
54+ auto pg_src = svptrue_b32 ();
55+ auto pg_dst = svwhilelt_b16 (svcnth () / 2 , svcnth ());
56+ auto vlen = svcntw ();
57+ size_t i = 0 ;
58+ for (; i + vlen < n; i += vlen) {
59+ auto load_src = svld1_f32 (pg_src, src + i);
60+ auto cvt_dst = svcvt_f16_f32_z (pg_src, load_src);
61+ auto str_dst = svuzp1 (cvt_dst, cvt_dst);
62+ svst1_f16 (pg_dst, dst_ptr + i, str_dst);
63+ }
64+ for (; i < n; i++) {
65+ dst[i] = src[i];
66+ }
67+ }
68+ } // namespace ov::Extensions::Cpu::XARCH
0 commit comments