3030#elif defined(OPENVINO_ARCH_ARM64) && defined(HAVE_SVE)
3131# include " arm_sve.h"
3232# include " nodes/kernels/aarch64/brgemm_kernel.hpp"
33- # include " nodes/kernels/aarch64/pa_kernels .hpp"
33+ # include " nodes/kernels/aarch64/sve_utils .hpp"
3434# include " nodes/kernels/kai/kleidi_kernel.hpp"
3535#endif
3636
@@ -39,7 +39,7 @@ namespace ov::Extensions::Cpu::XARCH {
3939using namespace ov ;
4040using namespace ov ::intel_cpu;
4141
42- // currently depends on brgemm which only support x64
42+ // currently depends on brgemm, which only supports x64 or ARM SVE
4343#if defined(OPENVINO_ARCH_X86_64) || (defined(OPENVINO_ARCH_ARM64) && defined(HAVE_SVE))
4444
4545# if defined(HAVE_AVX2) || defined(HAVE_AVX512F)
@@ -72,16 +72,27 @@ void cvt_copy(TA* dst, TB* src, size_t n) {
7272 }
7373# elif defined(HAVE_SVE)
7474 if constexpr (std::is_same<TA, TB>::value) {
75- SVE_PREDICATE ( pg_dst, TA)
76- SVE_VLEN ( vlen, TA)
75+ auto pg_dst = sve_predicate< sizeof ( TA)>();
76+ auto vlen = sve_vlen< sizeof ( TA)>();
7777 for (; i + vlen <= n; i += vlen) {
7878 auto vb = svld1 (pg_dst, src + i);
7979 svst1 (pg_dst, dst + i, vb);
8080 }
81- SVE_PREDICATE_WHILELT ( pgt, TA, i, n)
81+ auto pgt = sve_predicate< TA, sizeof (TA)>( i, n);
8282 auto vb = svld1 (pg_dst, src + i);
8383 svst1 (pg_dst, dst + i, vb);
8484 return ;
85+ } else if constexpr (std::is_same<TA, float >::value && std::is_same<TB, ov::float16>::value) {
86+ auto src_ptr = reinterpret_cast <float16_t *>(src);
87+ auto pg_vl2 = svwhilelt_b16 (svcnth () / 2 , svcnth ());
88+ auto vlen = svcnth () / 2 ;
89+ auto pg_dst = svptrue_b32 ();
90+ for (; i + vlen <= n; i += vlen) {
91+ auto load_src = svld1_f16 (pg_vl2, src_ptr + i);
92+ auto src_interleave = svzip1_f16 (load_src, load_src);
93+ auto cvt_dst = svcvt_f32_f16_z (pg_dst, src_interleave);
94+ svst1 (pg_dst, dst + i, cvt_dst);
95+ }
8596 }
8697# endif
8798 for (; i < n; i++) {
@@ -1715,7 +1726,7 @@ struct MHAHelper {
17151726 void init_reorder_buffers (size_t batch, size_t kv_len_in_blocks) {
17161727 _qk_scratch_b.resize <DATA_TYPE>({batch, kv_len_in_blocks, Hk, _block_size * S});
17171728 if (AarchF16) {
1718- // Required to keep kv_cache continuous in mem, as kleidi do to support accumulation
1729+ // The kv_cache must be kept contiguous in memory, as kleidi does not support accumulation
17191730 _wv_scratch_b.resize <DATA_TYPE>({batch, Hk, kv_len_in_blocks, _block_size * rnd_up (SV, _block_size)});
17201731 } else {
17211732 _wv_scratch_b.resize <DATA_TYPE>({batch, kv_len_in_blocks, Hk, _block_size * rnd_up (SV, _block_size)});
@@ -1918,7 +1929,7 @@ struct MHAHelper {
19181929 auto _score_stride = _weight.stride_bytes (2 ) / 2 ;
19191930 for (size_t h = hq_beg; h < hq_end; h++) {
19201931 auto * q_ptr = query.ptr <DATA_TYPE>(h, q_start, 0 );
1921- float * c_ptr = _weight.ptr <float >(ithr, h, 0 , 0 );
1932+ float * c_ptr = _weight.ptr <float >(ithr, h - hq_beg , 0 , 0 );
19221933 // for each query block, loop through all key block
19231934 // for blocks:
19241935 // 1 0 0 0 ...
@@ -1947,8 +1958,8 @@ struct MHAHelper {
19471958 for (size_t m = q_start; m < q_end; m++) {
19481959 // apply softmax in f32 precision
19491960 auto ncausal = (cur_kv_len - q_cnt + (m - q_start) + 1 );
1950- auto soft_in = _weight.ptr <float >(ithr, h, m - q_start);
1951- auto score = _weight.ptr <float >(ithr, h, m - q_start);
1961+ auto soft_in = _weight.ptr <float >(ithr, h - hq_beg , m - q_start);
1962+ auto score = _weight.ptr <float >(ithr, h - hq_beg , m - q_start);
19521963 PlainTensor f32_cvt;
19531964 if (q_is_xf16) {
19541965 f32_cvt.resize <float >({size_t {rnd_up (cur_kv_len, _block_size)}});
@@ -2007,7 +2018,7 @@ struct MHAHelper {
20072018 }
20082019
20092020 // reuse float buffer, need to use float to compute offset
2010- auto * w_ptr = reinterpret_cast <DATA_TYPE*>(_weight.ptr <float >(ithr, h, 0 , 0 ));
2021+ auto * w_ptr = reinterpret_cast <DATA_TYPE*>(_weight.ptr <float >(ithr, h - hq_beg , 0 , 0 ));
20112022 DATA_TYPE* out_ptr = output_emb.ptr <DATA_TYPE>(q_start, h * SV);
20122023 DATA_TYPE* v_ptr;
20132024 v_ptr = wv_scratch_b.ptr <DATA_TYPE>(hk, 0 );
0 commit comments