Skip to content

Commit 2f77855

Browse files
authored
Merge pull request #5181 from taoye9/change_sbgemn_cast_bf16
replace customize bf16_to_fp32 with arm neon vcvtah_f32_bf16
2 parents 66e0f1e + 4c00099 commit 2f77855

File tree

1 file changed

+38
-48
lines changed

1 file changed

+38
-48
lines changed

kernel/arm64/sbgemv_n_neon.c

Lines changed: 38 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,6 @@ THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3333
#include "common.h"
3434
#include <arm_neon.h>
3535

36-
#if (defined(__GNUC__) && __GNUC__ >= 13)
37-
#define BF16_TO_FP32(bf16) ((float)(bf16))
38-
#else
39-
static inline float bf16_to_fp32(bfloat16_t bf16) {
40-
uint32_t fp32 = (uint32_t)(*((u_int16_t*)(&bf16))) << 16;
41-
return *((float*)&fp32);
42-
}
43-
#define BF16_TO_FP32(bf16) bf16_to_fp32(bf16)
44-
#endif
45-
4636
static void beta_op(float *x, BLASLONG n, FLOAT beta) {
4737
if (beta == 0) {
4838
memset(x, 0, n * sizeof(float));
@@ -268,24 +258,24 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
268258
}
269259

270260
if (rest_m) {
271-
x0 = alpha * BF16_TO_FP32(x_ptr[0]);
272-
x1 = alpha * BF16_TO_FP32(x_ptr[1]);
273-
x2 = alpha * BF16_TO_FP32(x_ptr[2]);
274-
x3 = alpha * BF16_TO_FP32(x_ptr[3]);
275-
x4 = alpha * BF16_TO_FP32(x_ptr[4]);
276-
x5 = alpha * BF16_TO_FP32(x_ptr[5]);
277-
x6 = alpha * BF16_TO_FP32(x_ptr[6]);
278-
x7 = alpha * BF16_TO_FP32(x_ptr[7]);
261+
x0 = alpha * vcvtah_f32_bf16(x_ptr[0]);
262+
x1 = alpha * vcvtah_f32_bf16(x_ptr[1]);
263+
x2 = alpha * vcvtah_f32_bf16(x_ptr[2]);
264+
x3 = alpha * vcvtah_f32_bf16(x_ptr[3]);
265+
x4 = alpha * vcvtah_f32_bf16(x_ptr[4]);
266+
x5 = alpha * vcvtah_f32_bf16(x_ptr[5]);
267+
x6 = alpha * vcvtah_f32_bf16(x_ptr[6]);
268+
x7 = alpha * vcvtah_f32_bf16(x_ptr[7]);
279269

280270
for (BLASLONG j = 0; j < rest_m; j++) {
281-
y_ptr[j] += x0 * BF16_TO_FP32(a_ptr0[j]);
282-
y_ptr[j] += x1 * BF16_TO_FP32(a_ptr1[j]);
283-
y_ptr[j] += x2 * BF16_TO_FP32(a_ptr2[j]);
284-
y_ptr[j] += x3 * BF16_TO_FP32(a_ptr3[j]);
285-
y_ptr[j] += x4 * BF16_TO_FP32(a_ptr4[j]);
286-
y_ptr[j] += x5 * BF16_TO_FP32(a_ptr5[j]);
287-
y_ptr[j] += x6 * BF16_TO_FP32(a_ptr6[j]);
288-
y_ptr[j] += x7 * BF16_TO_FP32(a_ptr7[j]);
271+
y_ptr[j] += x0 * vcvtah_f32_bf16(a_ptr0[j]);
272+
y_ptr[j] += x1 * vcvtah_f32_bf16(a_ptr1[j]);
273+
y_ptr[j] += x2 * vcvtah_f32_bf16(a_ptr2[j]);
274+
y_ptr[j] += x3 * vcvtah_f32_bf16(a_ptr3[j]);
275+
y_ptr[j] += x4 * vcvtah_f32_bf16(a_ptr4[j]);
276+
y_ptr[j] += x5 * vcvtah_f32_bf16(a_ptr5[j]);
277+
y_ptr[j] += x6 * vcvtah_f32_bf16(a_ptr6[j]);
278+
y_ptr[j] += x7 * vcvtah_f32_bf16(a_ptr7[j]);
289279
}
290280
}
291281

@@ -384,16 +374,16 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
384374
}
385375

386376
if (rest_m) {
387-
x0 = alpha * BF16_TO_FP32(x_ptr[0]);
388-
x1 = alpha * BF16_TO_FP32(x_ptr[1]);
389-
x2 = alpha * BF16_TO_FP32(x_ptr[2]);
390-
x3 = alpha * BF16_TO_FP32(x_ptr[3]);
377+
x0 = alpha * vcvtah_f32_bf16(x_ptr[0]);
378+
x1 = alpha * vcvtah_f32_bf16(x_ptr[1]);
379+
x2 = alpha * vcvtah_f32_bf16(x_ptr[2]);
380+
x3 = alpha * vcvtah_f32_bf16(x_ptr[3]);
391381

392382
for (BLASLONG j = 0; j < rest_m; j++) {
393-
y_ptr[j] += x0 * BF16_TO_FP32(a_ptr0[j]);
394-
y_ptr[j] += x1 * BF16_TO_FP32(a_ptr1[j]);
395-
y_ptr[j] += x2 * BF16_TO_FP32(a_ptr2[j]);
396-
y_ptr[j] += x3 * BF16_TO_FP32(a_ptr3[j]);
383+
y_ptr[j] += x0 * vcvtah_f32_bf16(a_ptr0[j]);
384+
y_ptr[j] += x1 * vcvtah_f32_bf16(a_ptr1[j]);
385+
y_ptr[j] += x2 * vcvtah_f32_bf16(a_ptr2[j]);
386+
y_ptr[j] += x3 * vcvtah_f32_bf16(a_ptr3[j]);
397387
}
398388
}
399389

@@ -480,13 +470,13 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
480470
}
481471

482472
if (m & 2) {
483-
x0 = alpha * (BF16_TO_FP32(x_ptr[0]));
484-
x1 = alpha * (BF16_TO_FP32(x_ptr[1]));
473+
x0 = alpha * (vcvtah_f32_bf16(x_ptr[0]));
474+
x1 = alpha * (vcvtah_f32_bf16(x_ptr[1]));
485475

486-
y_ptr[0] += x0 * BF16_TO_FP32(a_ptr0[0]);
487-
y_ptr[0] += x1 * BF16_TO_FP32(a_ptr1[0]);
488-
y_ptr[1] += x0 * BF16_TO_FP32(a_ptr0[1]);
489-
y_ptr[1] += x1 * BF16_TO_FP32(a_ptr1[1]);
476+
y_ptr[0] += x0 * vcvtah_f32_bf16(a_ptr0[0]);
477+
y_ptr[0] += x1 * vcvtah_f32_bf16(a_ptr1[0]);
478+
y_ptr[1] += x0 * vcvtah_f32_bf16(a_ptr0[1]);
479+
y_ptr[1] += x1 * vcvtah_f32_bf16(a_ptr1[1]);
490480

491481
a_ptr0 += 2;
492482
a_ptr1 += 2;
@@ -495,23 +485,23 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
495485
}
496486

497487
if (m & 1) {
498-
x0 = alpha * BF16_TO_FP32(x_ptr[0]);
499-
x1 = alpha * BF16_TO_FP32(x_ptr[1]);
488+
x0 = alpha * vcvtah_f32_bf16(x_ptr[0]);
489+
x1 = alpha * vcvtah_f32_bf16(x_ptr[1]);
500490

501-
y_ptr[0] += x0 * BF16_TO_FP32(a_ptr0[0]);
502-
y_ptr[0] += x1 * BF16_TO_FP32(a_ptr1[0]);
491+
y_ptr[0] += x0 * vcvtah_f32_bf16(a_ptr0[0]);
492+
y_ptr[0] += x1 * vcvtah_f32_bf16(a_ptr1[0]);
503493
}
504494

505495
x_ptr += 2;
506496
}
507497

508498
if (n & 1) {
509-
x0 = BF16_TO_FP32(x_ptr[0]) * alpha;
499+
x0 = vcvtah_f32_bf16(x_ptr[0]) * alpha;
510500
y_ptr = y;
511501
a_ptr0 = a_ptr;
512502

513503
for (j = 0; j < m; j++) {
514-
y_ptr[j] += x0 * BF16_TO_FP32(a_ptr0[j]);
504+
y_ptr[j] += x0 * vcvtah_f32_bf16(a_ptr0[j]);
515505
}
516506
}
517507

@@ -525,10 +515,10 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
525515
}
526516

527517
for (j = 0; j < n; j++) {
528-
x0 = alpha * BF16_TO_FP32(*x_ptr);
518+
x0 = alpha * vcvtah_f32_bf16(*x_ptr);
529519
iy = 0;
530520
for (i = 0; i < m; i++) {
531-
y[iy] += x0 * BF16_TO_FP32(a_ptr[i]);
521+
y[iy] += x0 * vcvtah_f32_bf16(a_ptr[i]);
532522
iy += incy;
533523
}
534524

0 commit comments

Comments
 (0)