Skip to content

Commit 98d5401

Browse files
akote123, Sanket Kale
authored and committed
Enable Gelu Fp16
Separate platform-dependent code
1 parent 05b3d81 commit 98d5401

File tree

13 files changed

+948
-12
lines changed

13 files changed

+948
-12
lines changed

cmake/onnxruntime_mlas.cmake

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ function(setup_mlas_source_for_windows)
116116
${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp
117117
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
118118
${MLAS_SRC_DIR}/sconv_nchw_kernel_neon.cpp
119+
${MLAS_SRC_DIR}/erf_neon_fp16.h
120+
${MLAS_SRC_DIR}/erf_neon_fp16.cpp
119121
)
120122

121123
set(mlas_platform_preprocess_srcs
@@ -479,13 +481,17 @@ else()
479481
${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
480482
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
481483
${MLAS_SRC_DIR}/sconv_nchw_kernel_neon.cpp
484+
${MLAS_SRC_DIR}/erf_neon_fp16.h
485+
${MLAS_SRC_DIR}/erf_neon_fp16.cpp
482486
)
483487

484488
# Add the SVE implementation only when onnxruntime_USE_SVE is enabled.
if (onnxruntime_USE_SVE)
  # SVE header and kernels, appended in the same order as before.
  list(APPEND mlas_platform_srcs
    ${MLAS_SRC_DIR}/sve/mlasi_sve.h
    ${MLAS_SRC_DIR}/sve/elementwise_sve.cpp
    ${MLAS_SRC_DIR}/sve/Elementwise_sve_fp16.cpp
  )
  # Both SVE translation units are built with the same per-file -march value
  # (armv8.2-a with the sve and fp16 extensions).
  set_source_files_properties(
    ${MLAS_SRC_DIR}/sve/elementwise_sve.cpp
    ${MLAS_SRC_DIR}/sve/Elementwise_sve_fp16.cpp
    PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 "
  )
  # NOTE(review): presumably turned into a private MLAS_USE_SVE compile
  # definition where this list is consumed - confirm against the target setup.
  list(APPEND mlas_private_compile_definitions MLAS_USE_SVE)
endif()
491497

@@ -522,6 +528,7 @@ else()
522528
${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp
523529
${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp
524530
${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp
531+
${MLAS_SRC_DIR}/erf_neon_fp16.cpp
525532
)
526533
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
527534
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
@@ -538,6 +545,7 @@ else()
538545
set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
539546
set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
540547
set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
548+
set_source_files_properties(${MLAS_SRC_DIR}/erf_neon_fp16.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+fp16 ")
541549
endif()
542550

543551
if(ONNXRUNTIME_MLAS_MULTI_ARCH)

cmake/onnxruntime_providers_cpu.cmake

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,14 @@ if (onnxruntime_ENABLE_CPU_FP16_OPS)
182182
set_source_files_properties(${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/adasum_kernels.cc PROPERTIES COMPILE_FLAGS " -fassociative-math -ffast-math -ftree-vectorize -funsafe-math-optimizations -mf16c -mavx -mfma ")
183183
endif()
184184

185-
target_include_directories(onnxruntime_providers PRIVATE ${ONNXRUNTIME_ROOT})
185+
# Build gelu.cc with the armv8.2-a FP16 extension enabled on 64-bit Arm targets.
# -march is a GCC/Clang-style option, so it is not passed to MSVC, which does
# not recognise it (the "ARM64" spelling of onnxruntime_target_platform is
# presumably the Windows/MSVC one - confirm).
# NOTE(review): the flag applies to the whole translation unit; confirm that
# every FP16 code path in gelu.cc is only reached after a runtime check for
# FP16 support, otherwise older Armv8.0 cores could hit illegal instructions.
if((onnxruntime_target_platform STREQUAL "aarch64" OR onnxruntime_target_platform STREQUAL "ARM64" OR onnxruntime_target_platform STREQUAL "arm64") AND NOT MSVC)
  set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/tensor/gelu.cc" PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+fp16")
endif()
188+
# Private include paths for onnxruntime_providers: the ONNX Runtime source root
# and the core/mlas/inc directory beneath it.
target_include_directories(onnxruntime_providers
  PRIVATE
    ${ONNXRUNTIME_ROOT}
    ${ONNXRUNTIME_ROOT}/core/mlas/inc)
192+
186193
onnxruntime_add_include_to_target(onnxruntime_providers re2::re2 Eigen3::Eigen)
187194
add_dependencies(onnxruntime_providers onnx ${onnxruntime_EXTERNAL_DEPENDENCIES})
188195

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
/*++
2+
3+
Copyright 2025 FUJITSU LIMITED
4+
5+
Module Name:
6+
7+
erf_neon_fp16.cpp
8+
9+
Abstract:
10+
11+
This module implements the ERF kernel using NEON FP16 intrinsics.
12+
13+
--*/
14+
15+
#include "erf_neon_fp16.h"
16+
17+
// Helpers to safely convert between float and FP16-bit representation
18+
static float
19+
fp16_to_float(uint16_t h)
20+
{
21+
__fp16 tmp;
22+
memcpy(&tmp, &h, sizeof(h));
23+
return (float)tmp;
24+
}
25+
26+
static uint16_t
27+
float_to_fp16(float f)
28+
{
29+
__fp16 tmp = (__fp16)f;
30+
uint16_t h;
31+
memcpy(&h, &tmp, sizeof(h));
32+
return h;
33+
}
34+
35+
// Approximates exp(-x) lane-wise with a ratio of two quadratics,
//
//     (1.330 - 0.390 x + 0.0288 x^2) / (1.338 + 0.848 x + 0.467 x^2),
//
// after limiting x to at most 6. There is no lower clamp; the only caller in
// this file passes a squared value.
// NOTE(review): the formula above assumes MlasMultiplyAddFloat16(a, b, c)
// evaluates a * b + c; that helper is defined outside this file - confirm.
static inline MLAS_FLOAT16X8
exp_neg_rational_approx_f16(MLAS_FLOAT16X8 x)
{
    // Saturate the argument so the polynomials stay in their fitted range.
    const MLAS_FLOAT16X8 upper_bound = MlasBroadcastF16Float16x8(static_cast<float16_t>(6.0f));
    const MLAS_FLOAT16X8 xc = MlasMinimumFloat16(x, upper_bound);
    const MLAS_FLOAT16X8 xc_sq = MlasMultiplyFloat16(xc, xc);

    // Numerator: constant term first, then the linear and quadratic terms.
    MLAS_FLOAT16X8 numerator = MlasMultiplyAddFloat16(
        MlasBroadcastF16Float16x8(static_cast<float16_t>(-0.390f)),
        xc,
        MlasBroadcastF16Float16x8(static_cast<float16_t>(1.330f)));
    numerator = MlasMultiplyAddFloat16(
        MlasBroadcastF16Float16x8(static_cast<float16_t>(0.0288f)),
        xc_sq,
        numerator);

    // Denominator, accumulated in the same order.
    MLAS_FLOAT16X8 denominator = MlasMultiplyAddFloat16(
        MlasBroadcastF16Float16x8(static_cast<float16_t>(0.848f)),
        xc,
        MlasBroadcastF16Float16x8(static_cast<float16_t>(1.338f)));
    denominator = MlasMultiplyAddFloat16(
        MlasBroadcastF16Float16x8(static_cast<float16_t>(0.467f)),
        xc_sq,
        denominator);

    // 1 / denominator: hardware estimate refined by two Newton-Raphson steps.
    // MlasReciprocalSqrtFloat16 wraps the reciprocal-step intrinsic here.
    MLAS_FLOAT16X8 inv = MlasApproximateReciprocalFloat16(denominator);
    for (int step = 0; step < 2; ++step) {
        inv = MlasMultiplyFloat16(inv, MlasReciprocalSqrtFloat16(denominator, inv));
    }

    return MlasMultiplyFloat16(numerator, inv);
}
68+
69+
void
70+
MlasNeonErfKernelFp16(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N)
71+
{
72+
const float16_t p = 0.328f;
73+
const float16_t a1 = 0.2505f;
74+
const float16_t a2 = -0.2881f;
75+
const float16_t a3 = 1.4102f;
76+
const float16_t a4 = -1.423f;
77+
const float16_t a5 = 1.0547f;
78+
79+
MLAS_FLOAT16X8 vp = MlasBroadcastF16Float16x8(p);
80+
MLAS_FLOAT16X8 va1 = MlasBroadcastF16Float16x8(a1);
81+
MLAS_FLOAT16X8 va2 = MlasBroadcastF16Float16x8(a2);
82+
MLAS_FLOAT16X8 va3 = MlasBroadcastF16Float16x8(a3);
83+
MLAS_FLOAT16X8 va4 = MlasBroadcastF16Float16x8(a4);
84+
MLAS_FLOAT16X8 va5 = MlasBroadcastF16Float16x8(a5);
85+
86+
constexpr float16_t one_fp16 = 1.0f;
87+
constexpr float16_t neg_one_fp16 = -1.0f;
88+
constexpr float16_t zero_fp16 = 0.0f;
89+
constexpr float16_t four_fp16 = 4.0f;
90+
91+
MLAS_FLOAT16X8 vone = MlasBroadcastF16Float16x8(one_fp16);
92+
MLAS_FLOAT16X8 vneg_one = MlasBroadcastF16Float16x8(neg_one_fp16);
93+
MLAS_FLOAT16X8 vzero = MlasBroadcastF16Float16x8(zero_fp16);
94+
MLAS_FLOAT16X8 vth = MlasBroadcastF16Float16x8(four_fp16);
95+
96+
size_t i = 0;
97+
for (; i + 8 <= N; i += 8) {
98+
MLAS_FLOAT16X8 x = MlasLoadFloat16x8(&Input[i]);
99+
MLAS_UINT16X8 neg_mask = MlasCompareLessThanFloat16(x, vzero);
100+
MLAS_FLOAT16X8 sign = MlasSelectFloat16(neg_mask, vneg_one, vone);
101+
MLAS_FLOAT16X8 absx = MlasAbsFloat16(x);
102+
MLAS_UINT16X8 use_mask = MlasCompareLessThanFloat16(absx, vth);
103+
MLAS_FLOAT16X8 absx_clamped = MlasMinimumFloat16(absx, vth);
104+
MLAS_FLOAT16X8 denom = MlasMultiplyAddFloat16(vp, absx_clamped, vone);
105+
MLAS_FLOAT16X8 t = MlasApproximateReciprocalFloat16(denom);
106+
t = MlasMultiplyFloat16(t, MlasReciprocalSqrtFloat16(denom, t));
107+
t = MlasMultiplyFloat16(t, MlasReciprocalSqrtFloat16(denom, t));
108+
MLAS_FLOAT16X8 t2 = MlasMultiplyFloat16(t, t);
109+
MLAS_FLOAT16X8 t3 = MlasMultiplyFloat16(t2, t);
110+
MLAS_FLOAT16X8 t4 = MlasMultiplyFloat16(t3, t);
111+
MLAS_FLOAT16X8 t5 = MlasMultiplyFloat16(t4, t);
112+
MLAS_FLOAT16X8 poly = MlasMultiplyFloat16(va1, t);
113+
poly = MlasMultiplyAddFloat16(va2, t2, poly);
114+
poly = MlasMultiplyAddFloat16(va3, t3, poly);
115+
poly = MlasMultiplyAddFloat16(va4, t4, poly);
116+
poly = MlasMultiplyAddFloat16(va5, t5, poly);
117+
MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16(absx_clamped, absx_clamped);
118+
MLAS_FLOAT16X8 exp_neg_x2 = exp_neg_rational_approx_f16(x2);
119+
MLAS_FLOAT16X8 poly_mul_exp = MlasMultiplyFloat16(poly, exp_neg_x2);
120+
MLAS_FLOAT16X8 one_minus_term = MlasSubtractFloat16(vone, poly_mul_exp);
121+
MLAS_FLOAT16X8 erf_approx = MlasMultiplyFloat16(sign, one_minus_term);
122+
erf_approx = MlasMinimumFloat16(erf_approx, vone);
123+
erf_approx = MlasMaximumFloat16(erf_approx, vneg_one);
124+
MLAS_FLOAT16X8 result = MlasSelectFloat16(use_mask, erf_approx, sign);
125+
MlasStoreFloat16x8(&Output[i], result);
126+
}
127+
128+
for (; i < N; i++) {
129+
float x = fp16_to_float(Input[i]);
130+
float sign = (x < 0) ? -1.0f : 1.0f;
131+
float absx = fabsf(x);
132+
133+
if (absx > 4.0f) {
134+
Output[i] = float_to_fp16(sign);
135+
continue;
136+
}
137+
138+
float t = 1.0f / (1.0f + p * absx);
139+
float poly = a1 * t + a2 * t * t + a3 * t * t * t + a4 * t * t * t * t + a5 * t * t * t * t * t;
140+
float exp_neg_x2 = expf(-absx * absx);
141+
float erf_approx = sign * (1.0f - poly * exp_neg_x2);
142+
if (erf_approx > 1.0f) erf_approx = 1.0f;
143+
if (erf_approx < -1.0f) erf_approx = -1.0f;
144+
145+
Output[i] = float_to_fp16(erf_approx);
146+
}
147+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/*++
2+
3+
Copyright 2025 FUJITSU LIMITED
4+
5+
Module Name:
6+
7+
erf_neon_fp16.h
8+
9+
Abstract:
10+
11+
This module contains the procedure prototypes for the ERF NEON FP16 intrinsics.
12+
13+
--*/
14+
15+
#pragma once
16+
17+
#include <arm_neon.h>
18+
19+
#include "mlasi.h"
20+
#include "fp16_common.h"
21+
#include "softmax_kernel_neon.h"
22+
23+
// Half-precision values are passed through this interface as their raw 16-bit encoding.
using _mlas_fp16_ = uint16_t;
24+
// Writes an approximation of erf(Input[i]) to Output[i] for every i in [0, N).
// Input and Output hold raw 16-bit half-precision values.
void MlasNeonErfKernelFp16(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N);

onnxruntime/core/mlas/lib/fp16_common.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ MLAS_FORCEINLINE
5454
MLAS_FLOAT16X8
5555
MlasBroadcastFloat16x8(_mlas_fp16_ Value) { return vreinterpretq_f16_p16(vdupq_n_p16(Value)); }
5656

57+
// Replicates a native float16_t scalar into all eight lanes of a vector.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasBroadcastF16Float16x8(float16_t Value)
{
    const MLAS_FLOAT16X8 Splat = vdupq_n_f16(Value);
    return Splat;
}
60+
5761
// Replicates a raw 16-bit half-precision pattern into all four lanes and
// reinterprets the result as a half-precision vector.
MLAS_FORCEINLINE
MLAS_FLOAT16X4
MlasBroadcastFloat16x4(_mlas_fp16_ Value)
{
    const poly16x4_t Bits = vdup_n_p16(Value);
    return vreinterpret_f16_p16(Bits);
}
@@ -78,6 +82,10 @@ MLAS_FORCEINLINE
7882
MLAS_FLOAT16X8
7983
MlasLoadFloat16x8(const _mlas_fp16_* Buffer) { return vreinterpretq_f16_u16(vld1q_u16(Buffer)); }
8084

85+
// Loads eight consecutive float16_t values starting at Buffer.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasLoadf16Float16x8(const float16_t* Buffer)
{
    const MLAS_FLOAT16X8 Loaded = vld1q_f16(Buffer);
    return Loaded;
}
88+
8189
// Loads four consecutive raw 16-bit values starting at Buffer and
// reinterprets them as a half-precision vector.
MLAS_FORCEINLINE
MLAS_FLOAT16X4
MlasLoadFloat16x4(const _mlas_fp16_* Buffer)
{
    const uint16x4_t Bits = vld1_u16(Buffer);
    return vreinterpret_f16_u16(Bits);
}
@@ -115,6 +123,13 @@ MlasStoreFloat16x8(_mlas_fp16_* Buffer, MLAS_FLOAT16X8 Vector)
115123
vst1q_u16(Buffer, vreinterpretq_u16_f16(Vector));
116124
}
117125

126+
// Stores the eight lanes of Vector to consecutive float16_t slots at Buffer.
MLAS_FORCEINLINE
void
MlasStoref16Float16x8(float16_t* Buffer, MLAS_FLOAT16X8 Vector)
{
    float16_t* const Destination = Buffer;
    vst1q_f16(Destination, Vector);
}
132+
118133
MLAS_FORCEINLINE
119134
void
120135
MlasStoreFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector)
@@ -579,4 +594,39 @@ MlasShiftLeftInt16(MLAS_INT16X4 Vector)
579594
return vshl_n_s16(Vector, ShiftCount);
580595
}
581596

597+
// Newton-Raphson reciprocal refinement step: wraps vrecpsq_f16, which yields
// (2 - Vector1 * Vector2) per lane. Multiplying a reciprocal estimate of
// Vector1 (passed as Vector2) by the result refines that estimate.
// Note: despite the name, this is the reciprocal step, not a reciprocal
// square root.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasReciprocalSqrtFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
{
    const MLAS_FLOAT16X8 StepFactor = vrecpsq_f16(Vector1, Vector2);
    return StepFactor;
}
603+
604+
// Returns the hardware reciprocal estimate (vrecpeq_f16) of each lane. The
// estimate is low precision; callers refine it with reciprocal-step
// iterations when more accuracy is needed.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasApproximateReciprocalFloat16(MLAS_FLOAT16X8 Vector)
{
    const MLAS_FLOAT16X8 Estimate = vrecpeq_f16(Vector);
    return Estimate;
}
610+
611+
// Lane-wise Vector1 < Vector2. Each result lane is all ones when the
// comparison holds and zero otherwise.
MLAS_FORCEINLINE
MLAS_UINT16X8
MlasCompareLessThanFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
{
    const MLAS_UINT16X8 Mask = vcltq_f16(Vector1, Vector2);
    return Mask;
}
617+
618+
// Returns the lane-wise absolute value of Vector.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasAbsFloat16(MLAS_FLOAT16X8 Vector)
{
    const MLAS_FLOAT16X8 Magnitude = vabsq_f16(Vector);
    return Magnitude;
}
624+
625+
// Bitwise select: for every bit set in Vector (the mask) the result takes the
// bit from Vector1, otherwise from Vector2. With all-ones/all-zeros lane
// masks this picks whole lanes.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasSelectFloat16(MLAS_UINT16X8 Vector, MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
{
    const MLAS_FLOAT16X8 Selected = vbslq_f16(Vector, Vector1, Vector2);
    return Selected;
}
631+
582632
#endif // fp16 vector intrinsic supported

0 commit comments

Comments
 (0)