@@ -33,16 +33,6 @@ THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33
33
#include "common.h"
34
34
#include <arm_neon.h>
35
35
36
- #if (defined(__GNUC__ ) && __GNUC__ >= 13 )
37
- #define BF16_TO_FP32 (bf16 ) ((float)(bf16))
38
- #else
39
- static inline float bf16_to_fp32 (bfloat16_t bf16 ) {
40
- uint32_t fp32 = (uint32_t )(* ((u_int16_t * )(& bf16 ))) << 16 ;
41
- return * ((float * )& fp32 );
42
- }
43
- #define BF16_TO_FP32 (bf16 ) bf16_to_fp32(bf16)
44
- #endif
45
-
46
36
static void beta_op (float * x , BLASLONG n , FLOAT beta ) {
47
37
if (beta == 0 ) {
48
38
memset (x , 0 , n * sizeof (float ));
@@ -268,24 +258,24 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
268
258
}
269
259
270
260
if (rest_m ) {
271
- x0 = alpha * BF16_TO_FP32 (x_ptr [0 ]);
272
- x1 = alpha * BF16_TO_FP32 (x_ptr [1 ]);
273
- x2 = alpha * BF16_TO_FP32 (x_ptr [2 ]);
274
- x3 = alpha * BF16_TO_FP32 (x_ptr [3 ]);
275
- x4 = alpha * BF16_TO_FP32 (x_ptr [4 ]);
276
- x5 = alpha * BF16_TO_FP32 (x_ptr [5 ]);
277
- x6 = alpha * BF16_TO_FP32 (x_ptr [6 ]);
278
- x7 = alpha * BF16_TO_FP32 (x_ptr [7 ]);
261
+ x0 = alpha * vcvtah_f32_bf16 (x_ptr [0 ]);
262
+ x1 = alpha * vcvtah_f32_bf16 (x_ptr [1 ]);
263
+ x2 = alpha * vcvtah_f32_bf16 (x_ptr [2 ]);
264
+ x3 = alpha * vcvtah_f32_bf16 (x_ptr [3 ]);
265
+ x4 = alpha * vcvtah_f32_bf16 (x_ptr [4 ]);
266
+ x5 = alpha * vcvtah_f32_bf16 (x_ptr [5 ]);
267
+ x6 = alpha * vcvtah_f32_bf16 (x_ptr [6 ]);
268
+ x7 = alpha * vcvtah_f32_bf16 (x_ptr [7 ]);
279
269
280
270
for (BLASLONG j = 0 ; j < rest_m ; j ++ ) {
281
- y_ptr [j ] += x0 * BF16_TO_FP32 (a_ptr0 [j ]);
282
- y_ptr [j ] += x1 * BF16_TO_FP32 (a_ptr1 [j ]);
283
- y_ptr [j ] += x2 * BF16_TO_FP32 (a_ptr2 [j ]);
284
- y_ptr [j ] += x3 * BF16_TO_FP32 (a_ptr3 [j ]);
285
- y_ptr [j ] += x4 * BF16_TO_FP32 (a_ptr4 [j ]);
286
- y_ptr [j ] += x5 * BF16_TO_FP32 (a_ptr5 [j ]);
287
- y_ptr [j ] += x6 * BF16_TO_FP32 (a_ptr6 [j ]);
288
- y_ptr [j ] += x7 * BF16_TO_FP32 (a_ptr7 [j ]);
271
+ y_ptr [j ] += x0 * vcvtah_f32_bf16 (a_ptr0 [j ]);
272
+ y_ptr [j ] += x1 * vcvtah_f32_bf16 (a_ptr1 [j ]);
273
+ y_ptr [j ] += x2 * vcvtah_f32_bf16 (a_ptr2 [j ]);
274
+ y_ptr [j ] += x3 * vcvtah_f32_bf16 (a_ptr3 [j ]);
275
+ y_ptr [j ] += x4 * vcvtah_f32_bf16 (a_ptr4 [j ]);
276
+ y_ptr [j ] += x5 * vcvtah_f32_bf16 (a_ptr5 [j ]);
277
+ y_ptr [j ] += x6 * vcvtah_f32_bf16 (a_ptr6 [j ]);
278
+ y_ptr [j ] += x7 * vcvtah_f32_bf16 (a_ptr7 [j ]);
289
279
}
290
280
}
291
281
@@ -384,16 +374,16 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
384
374
}
385
375
386
376
if (rest_m ) {
387
- x0 = alpha * BF16_TO_FP32 (x_ptr [0 ]);
388
- x1 = alpha * BF16_TO_FP32 (x_ptr [1 ]);
389
- x2 = alpha * BF16_TO_FP32 (x_ptr [2 ]);
390
- x3 = alpha * BF16_TO_FP32 (x_ptr [3 ]);
377
+ x0 = alpha * vcvtah_f32_bf16 (x_ptr [0 ]);
378
+ x1 = alpha * vcvtah_f32_bf16 (x_ptr [1 ]);
379
+ x2 = alpha * vcvtah_f32_bf16 (x_ptr [2 ]);
380
+ x3 = alpha * vcvtah_f32_bf16 (x_ptr [3 ]);
391
381
392
382
for (BLASLONG j = 0 ; j < rest_m ; j ++ ) {
393
- y_ptr [j ] += x0 * BF16_TO_FP32 (a_ptr0 [j ]);
394
- y_ptr [j ] += x1 * BF16_TO_FP32 (a_ptr1 [j ]);
395
- y_ptr [j ] += x2 * BF16_TO_FP32 (a_ptr2 [j ]);
396
- y_ptr [j ] += x3 * BF16_TO_FP32 (a_ptr3 [j ]);
383
+ y_ptr [j ] += x0 * vcvtah_f32_bf16 (a_ptr0 [j ]);
384
+ y_ptr [j ] += x1 * vcvtah_f32_bf16 (a_ptr1 [j ]);
385
+ y_ptr [j ] += x2 * vcvtah_f32_bf16 (a_ptr2 [j ]);
386
+ y_ptr [j ] += x3 * vcvtah_f32_bf16 (a_ptr3 [j ]);
397
387
}
398
388
}
399
389
@@ -480,13 +470,13 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
480
470
}
481
471
482
472
if (m & 2 ) {
483
- x0 = alpha * (BF16_TO_FP32 (x_ptr [0 ]));
484
- x1 = alpha * (BF16_TO_FP32 (x_ptr [1 ]));
473
+ x0 = alpha * (vcvtah_f32_bf16 (x_ptr [0 ]));
474
+ x1 = alpha * (vcvtah_f32_bf16 (x_ptr [1 ]));
485
475
486
- y_ptr [0 ] += x0 * BF16_TO_FP32 (a_ptr0 [0 ]);
487
- y_ptr [0 ] += x1 * BF16_TO_FP32 (a_ptr1 [0 ]);
488
- y_ptr [1 ] += x0 * BF16_TO_FP32 (a_ptr0 [1 ]);
489
- y_ptr [1 ] += x1 * BF16_TO_FP32 (a_ptr1 [1 ]);
476
+ y_ptr [0 ] += x0 * vcvtah_f32_bf16 (a_ptr0 [0 ]);
477
+ y_ptr [0 ] += x1 * vcvtah_f32_bf16 (a_ptr1 [0 ]);
478
+ y_ptr [1 ] += x0 * vcvtah_f32_bf16 (a_ptr0 [1 ]);
479
+ y_ptr [1 ] += x1 * vcvtah_f32_bf16 (a_ptr1 [1 ]);
490
480
491
481
a_ptr0 += 2 ;
492
482
a_ptr1 += 2 ;
@@ -495,23 +485,23 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
495
485
}
496
486
497
487
if (m & 1 ) {
498
- x0 = alpha * BF16_TO_FP32 (x_ptr [0 ]);
499
- x1 = alpha * BF16_TO_FP32 (x_ptr [1 ]);
488
+ x0 = alpha * vcvtah_f32_bf16 (x_ptr [0 ]);
489
+ x1 = alpha * vcvtah_f32_bf16 (x_ptr [1 ]);
500
490
501
- y_ptr [0 ] += x0 * BF16_TO_FP32 (a_ptr0 [0 ]);
502
- y_ptr [0 ] += x1 * BF16_TO_FP32 (a_ptr1 [0 ]);
491
+ y_ptr [0 ] += x0 * vcvtah_f32_bf16 (a_ptr0 [0 ]);
492
+ y_ptr [0 ] += x1 * vcvtah_f32_bf16 (a_ptr1 [0 ]);
503
493
}
504
494
505
495
x_ptr += 2 ;
506
496
}
507
497
508
498
if (n & 1 ) {
509
- x0 = BF16_TO_FP32 (x_ptr [0 ]) * alpha ;
499
+ x0 = vcvtah_f32_bf16 (x_ptr [0 ]) * alpha ;
510
500
y_ptr = y ;
511
501
a_ptr0 = a_ptr ;
512
502
513
503
for (j = 0 ; j < m ; j ++ ) {
514
- y_ptr [j ] += x0 * BF16_TO_FP32 (a_ptr0 [j ]);
504
+ y_ptr [j ] += x0 * vcvtah_f32_bf16 (a_ptr0 [j ]);
515
505
}
516
506
}
517
507
@@ -525,10 +515,10 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
525
515
}
526
516
527
517
for (j = 0 ; j < n ; j ++ ) {
528
- x0 = alpha * BF16_TO_FP32 (* x_ptr );
518
+ x0 = alpha * vcvtah_f32_bf16 (* x_ptr );
529
519
iy = 0 ;
530
520
for (i = 0 ; i < m ; i ++ ) {
531
- y [iy ] += x0 * BF16_TO_FP32 (a_ptr [i ]);
521
+ y [iy ] += x0 * vcvtah_f32_bf16 (a_ptr [i ]);
532
522
iy += incy ;
533
523
}
534
524
0 commit comments