Commit cc2625d

Enable Gelu Fp16
Separate platform-dependent code
1 parent 52a38a5 commit cc2625d

13 files changed: 948 additions & 12 deletions

cmake/onnxruntime_mlas.cmake

Lines changed: 8 additions & 0 deletions
@@ -113,6 +113,8 @@ function(setup_mlas_source_for_windows)
 ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
 ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp
 ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
+${MLAS_SRC_DIR}/erf_neon_fp16.h
+${MLAS_SRC_DIR}/erf_neon_fp16.cpp
 )

 set(mlas_platform_preprocess_srcs
@@ -460,13 +462,17 @@ else()
 ${MLAS_SRC_DIR}/eltwise_kernel_neon.h
 ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
 ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
+${MLAS_SRC_DIR}/erf_neon_fp16.h
+${MLAS_SRC_DIR}/erf_neon_fp16.cpp
 )

 # Conditionally add the SVE implementation if compiler supports it
 if (onnxruntime_USE_SVE)
 list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/mlasi_sve.h)
 list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/elementwise_sve.cpp)
+list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/Elementwise_sve_fp16.cpp)
 set_source_files_properties(${MLAS_SRC_DIR}/sve/elementwise_sve.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 ")
+set_source_files_properties(${MLAS_SRC_DIR}/sve/Elementwise_sve_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 ")
 list(APPEND mlas_private_compile_definitions MLAS_USE_SVE)
 endif()

@@ -502,6 +508,7 @@ else()
 ${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp
 ${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp
 ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp
+${MLAS_SRC_DIR}/erf_neon_fp16.cpp
 )
 set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
 set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
@@ -517,6 +524,7 @@ else()
 set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
 set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
 set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+set_source_files_properties(${MLAS_SRC_DIR}/erf_neon_fp16.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+fp16 ")
 endif()

 if(ONNXRUNTIME_MLAS_MULTI_ARCH)
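
Not part of the commit, but context for the per-file COMPILE_FLAGS above: the FP16 NEON intrinsics used by the new kernel only compile when the translation unit targets armv8.2-a+fp16, so the flag is scoped to the files that need it rather than applied target-wide. A minimal illustrative check (hypothetical file name, not from the ONNX Runtime sources):

// fp16_flag_check.cpp -- builds only with an aarch64 compiler and -march=armv8.2-a+fp16
#include <arm_neon.h>

float16x8_t DoubleFp16(float16x8_t v)
{
    // vaddq_f16 belongs to the Armv8.2-A FP16 extension; without +fp16 this fails to compile.
    return vaddq_f16(v, v);
}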

cmake/onnxruntime_providers_cpu.cmake

Lines changed: 8 additions & 1 deletion
@@ -172,7 +172,14 @@ if (onnxruntime_ENABLE_CPU_FP16_OPS)
 set_source_files_properties(${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/adasum_kernels.cc PROPERTIES COMPILE_FLAGS " -fassociative-math -ffast-math -ftree-vectorize -funsafe-math-optimizations -mf16c -mavx -mfma ")
 endif()

-target_include_directories(onnxruntime_providers PRIVATE ${ONNXRUNTIME_ROOT})
+if(onnxruntime_target_platform STREQUAL "aarch64" OR onnxruntime_target_platform STREQUAL "ARM64" OR onnxruntime_target_platform STREQUAL "arm64")
+set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/tensor/gelu.cc" PROPERTIES COMPILE_FLAGS -march=armv8.2-a+fp16)
+endif()
+target_include_directories(onnxruntime_providers PRIVATE
+${ONNXRUNTIME_ROOT}
+${ONNXRUNTIME_ROOT}/core/mlas/inc
+)
+
 onnxruntime_add_include_to_target(onnxruntime_providers re2::re2 Eigen3::Eigen)
 add_dependencies(onnxruntime_providers onnx ${onnxruntime_EXTERNAL_DEPENDENCIES})
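
For orientation (not part of the diff): gelu.cc gets the FP16 architecture flag and the MLAS include directory because the Gelu operator is defined through erf, Gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))), and that erf term is what the new FP16 kernel accelerates. A minimal scalar reference, with an illustrative function name that is not ORT's API:

#include <cmath>

float GeluReference(float x)
{
    // 0.70710678f is 1/sqrt(2); std::erf is the term the NEON FP16 kernel approximates.
    return 0.5f * x * (1.0f + std::erf(x * 0.70710678f));
}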

onnxruntime/core/mlas/lib/erf_neon_fp16.cpp

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
+/*++
+
+Copyright 2025 FUJITSU LIMITED
+
+Module Name:
+
+    erf_neon_fp16.cpp
+
+Abstract:
+
+    This module contains the procedure prototypes for the ERF NEON FP16 intrinsics.
+
+--*/
+
+#include "erf_neon_fp16.h"
+
+// Helpers to safely convert between float and FP16-bit representation
+static float
+fp16_to_float(uint16_t h)
+{
+    __fp16 tmp;
+    memcpy(&tmp, &h, sizeof(h));
+    return (float)tmp;
+}
+
+static uint16_t
+float_to_fp16(float f)
+{
+    __fp16 tmp = (__fp16)f;
+    uint16_t h;
+    memcpy(&h, &tmp, sizeof(h));
+    return h;
+}
+
+static inline MLAS_FLOAT16X8
+exp_neg_rational_approx_f16(MLAS_FLOAT16X8 x)
+{
+    const float16_t a0 = 6.0f;
+    MLAS_FLOAT16X8 max_x = MlasBroadcastF16Float16x8(a0);
+    x = MlasMinimumFloat16(x, max_x);
+
+    const float16_t c0 = 1.330f;
+    const float16_t c1 = -0.390f;
+    const float16_t c2 = 0.0288f;
+
+    const float16_t d0 = 1.338f;
+    const float16_t d1 = 0.848f;
+    const float16_t d2 = 0.467f;
+
+    MLAS_FLOAT16X8 c0v = MlasBroadcastF16Float16x8(c0);
+    MLAS_FLOAT16X8 c1v = MlasBroadcastF16Float16x8(c1);
+    MLAS_FLOAT16X8 c2v = MlasBroadcastF16Float16x8(c2);
+
+    MLAS_FLOAT16X8 d0v = MlasBroadcastF16Float16x8(d0);
+    MLAS_FLOAT16X8 d1v = MlasBroadcastF16Float16x8(d1);
+    MLAS_FLOAT16X8 d2v = MlasBroadcastF16Float16x8(d2);
+    MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16(x, x);
+    MLAS_FLOAT16X8 num = MlasMultiplyAddFloat16(c1v, x, c0v);
+    num = MlasMultiplyAddFloat16(c2v, x2, num);
+    MLAS_FLOAT16X8 den = MlasMultiplyAddFloat16(d1v, x, d0v);
+    den = MlasMultiplyAddFloat16(d2v, x2, den);
+    MLAS_FLOAT16X8 recip = MlasApproximateReciprocalFloat16(den);
+    recip = MlasMultiplyFloat16(recip, MlasReciprocalSqrtFloat16(den, recip));
+    recip = MlasMultiplyFloat16(recip, MlasReciprocalSqrtFloat16(den, recip));
+    MLAS_FLOAT16X8 result = MlasMultiplyFloat16(num, recip);
+    return result;
+}
+
+void
+MlasNeonErfKernelFp16(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N)
+{
+    const float16_t p = 0.328f;
+    const float16_t a1 = 0.2505f;
+    const float16_t a2 = -0.2881f;
+    const float16_t a3 = 1.4102f;
+    const float16_t a4 = -1.423f;
+    const float16_t a5 = 1.0547f;
+
+    MLAS_FLOAT16X8 vp = MlasBroadcastF16Float16x8(p);
+    MLAS_FLOAT16X8 va1 = MlasBroadcastF16Float16x8(a1);
+    MLAS_FLOAT16X8 va2 = MlasBroadcastF16Float16x8(a2);
+    MLAS_FLOAT16X8 va3 = MlasBroadcastF16Float16x8(a3);
+    MLAS_FLOAT16X8 va4 = MlasBroadcastF16Float16x8(a4);
+    MLAS_FLOAT16X8 va5 = MlasBroadcastF16Float16x8(a5);
+
+    constexpr float16_t one_fp16 = 1.0f;
+    constexpr float16_t neg_one_fp16 = -1.0f;
+    constexpr float16_t zero_fp16 = 0.0f;
+    constexpr float16_t four_fp16 = 4.0f;
+
+    MLAS_FLOAT16X8 vone = MlasBroadcastF16Float16x8(one_fp16);
+    MLAS_FLOAT16X8 vneg_one = MlasBroadcastF16Float16x8(neg_one_fp16);
+    MLAS_FLOAT16X8 vzero = MlasBroadcastF16Float16x8(zero_fp16);
+    MLAS_FLOAT16X8 vth = MlasBroadcastF16Float16x8(four_fp16);
+
+    size_t i = 0;
+    for (; i + 8 <= N; i += 8) {
+        MLAS_FLOAT16X8 x = MlasLoadFloat16x8(&Input[i]);
+        MLAS_UINT16X8 neg_mask = MlasCompareLessThanFloat16(x, vzero);
+        MLAS_FLOAT16X8 sign = MlasSelectFloat16(neg_mask, vneg_one, vone);
+        MLAS_FLOAT16X8 absx = MlasAbsFloat16(x);
+        MLAS_UINT16X8 use_mask = MlasCompareLessThanFloat16(absx, vth);
+        MLAS_FLOAT16X8 absx_clamped = MlasMinimumFloat16(absx, vth);
+        MLAS_FLOAT16X8 denom = MlasMultiplyAddFloat16(vp, absx_clamped, vone);
+        MLAS_FLOAT16X8 t = MlasApproximateReciprocalFloat16(denom);
+        t = MlasMultiplyFloat16(t, MlasReciprocalSqrtFloat16(denom, t));
+        t = MlasMultiplyFloat16(t, MlasReciprocalSqrtFloat16(denom, t));
+        MLAS_FLOAT16X8 t2 = MlasMultiplyFloat16(t, t);
+        MLAS_FLOAT16X8 t3 = MlasMultiplyFloat16(t2, t);
+        MLAS_FLOAT16X8 t4 = MlasMultiplyFloat16(t3, t);
+        MLAS_FLOAT16X8 t5 = MlasMultiplyFloat16(t4, t);
+        MLAS_FLOAT16X8 poly = MlasMultiplyFloat16(va1, t);
+        poly = MlasMultiplyAddFloat16(va2, t2, poly);
+        poly = MlasMultiplyAddFloat16(va3, t3, poly);
+        poly = MlasMultiplyAddFloat16(va4, t4, poly);
+        poly = MlasMultiplyAddFloat16(va5, t5, poly);
+        MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16(absx_clamped, absx_clamped);
+        MLAS_FLOAT16X8 exp_neg_x2 = exp_neg_rational_approx_f16(x2);
+        MLAS_FLOAT16X8 poly_mul_exp = MlasMultiplyFloat16(poly, exp_neg_x2);
+        MLAS_FLOAT16X8 one_minus_term = MlasSubtractFloat16(vone, poly_mul_exp);
+        MLAS_FLOAT16X8 erf_approx = MlasMultiplyFloat16(sign, one_minus_term);
+        erf_approx = MlasMinimumFloat16(erf_approx, vone);
+        erf_approx = MlasMaximumFloat16(erf_approx, vneg_one);
+        MLAS_FLOAT16X8 result = MlasSelectFloat16(use_mask, erf_approx, sign);
+        MlasStoreFloat16x8(&Output[i], result);
+    }
+
+    for (; i < N; i++) {
+        float x = fp16_to_float(Input[i]);
+        float sign = (x < 0) ? -1.0f : 1.0f;
+        float absx = fabsf(x);
+
+        if (absx > 4.0f) {
+            Output[i] = float_to_fp16(sign);
+            continue;
+        }
+
+        float t = 1.0f / (1.0f + p * absx);
+        float poly = a1 * t + a2 * t * t + a3 * t * t * t + a4 * t * t * t * t + a5 * t * t * t * t * t;
+        float exp_neg_x2 = expf(-absx * absx);
+        float erf_approx = sign * (1.0f - poly * exp_neg_x2);
+        if (erf_approx > 1.0f) erf_approx = 1.0f;
+        if (erf_approx < -1.0f) erf_approx = -1.0f;
+
+        Output[i] = float_to_fp16(erf_approx);
+    }
+}
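
For reference (not part of the commit): the kernel follows the Abramowitz & Stegun 7.1.26 form erf(x) ≈ sign(x) * (1 - poly(t) * exp(-x^2)) with t = 1/(1 + p*|x|), using FP16-friendly constants, saturation to ±1 for |x| > 4, and, on the vector path, a low-order rational approximation of exp(-x) (clamped at x = 6) in place of expf. A standalone scalar sketch of the same approximation, reusing the constants from the tail loop above and compared against std::erf:

#include <cmath>
#include <cstdio>

static float ErfApprox(float x)
{
    const float p = 0.328f;
    const float a1 = 0.2505f, a2 = -0.2881f, a3 = 1.4102f, a4 = -1.423f, a5 = 1.0547f;

    float sign = (x < 0.0f) ? -1.0f : 1.0f;
    float ax = std::fabs(x);
    if (ax > 4.0f) return sign;  // saturate to +/-1, as the vector path does via its select mask

    float t = 1.0f / (1.0f + p * ax);
    float poly = t * (a1 + t * (a2 + t * (a3 + t * (a4 + t * a5))));
    float y = sign * (1.0f - poly * std::exp(-ax * ax));
    return y < -1.0f ? -1.0f : (y > 1.0f ? 1.0f : y);
}

int main()
{
    for (float x = -3.0f; x <= 3.0f; x += 0.5f) {
        std::printf("x=% .2f  approx=% .5f  std::erf=% .5f\n", x, ErfApprox(x), std::erf(x));
    }
    return 0;
}
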
onnxruntime/core/mlas/lib/erf_neon_fp16.h

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+/*++
+
+Copyright 2025 FUJITSU LIMITED
+
+Module Name:
+
+    erf_neon_fp16.h
+
+Abstract:
+
+    This module contains the procedure prototypes for the ERF NEON FP16 intrinsics.
+
+--*/
+
+#pragma once
+
+#include <arm_neon.h>
+
+#include "mlasi.h"
+#include "fp16_common.h"
+#include "softmax_kernel_neon.h"
+
+using _mlas_fp16_ = uint16_t;
+void MlasNeonErfKernelFp16(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N);
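
A hypothetical caller sketch (not from the ONNX Runtime sources): the kernel operates on raw FP16 bit patterns (_mlas_fp16_ is uint16_t), so a caller holding float data converts at the boundaries. Assumes an aarch64 toolchain with __fp16 support and linking against the object built from erf_neon_fp16.cpp; the conversion helpers mirror the static ones in that file and the names here are illustrative only.

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

using _mlas_fp16_ = uint16_t;
// Prototype as declared in erf_neon_fp16.h.
void MlasNeonErfKernelFp16(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N);

static _mlas_fp16_ FloatToFp16(float f)
{
    __fp16 h = (__fp16)f;
    _mlas_fp16_ bits;
    std::memcpy(&bits, &h, sizeof(bits));
    return bits;
}

static float Fp16ToFloat(_mlas_fp16_ bits)
{
    __fp16 h;
    std::memcpy(&h, &bits, sizeof(h));
    return (float)h;
}

int main()
{
    std::vector<float> x = {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f};
    std::vector<_mlas_fp16_> in(x.size()), out(x.size());
    for (size_t i = 0; i < x.size(); ++i) in[i] = FloatToFp16(x[i]);

    MlasNeonErfKernelFp16(in.data(), out.data(), in.size());

    for (size_t i = 0; i < x.size(); ++i) {
        float y = Fp16ToFloat(out[i]);  // erf(x[i]) at FP16 precision
        (void)y;
    }
    return 0;
}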

onnxruntime/core/mlas/lib/fp16_common.h

Lines changed: 50 additions & 0 deletions
@@ -54,6 +54,10 @@ MLAS_FORCEINLINE
 MLAS_FLOAT16X8
 MlasBroadcastFloat16x8(_mlas_fp16_ Value) { return vreinterpretq_f16_p16(vdupq_n_p16(Value)); }

+MLAS_FORCEINLINE
+MLAS_FLOAT16X8
+MlasBroadcastF16Float16x8(float16_t Value) { return vdupq_n_f16(Value); }
+
 MLAS_FORCEINLINE
 MLAS_FLOAT16X4
 MlasBroadcastFloat16x4(_mlas_fp16_ Value) { return vreinterpret_f16_p16(vdup_n_p16(Value)); }
@@ -78,6 +82,10 @@ MLAS_FORCEINLINE
 MLAS_FLOAT16X8
 MlasLoadFloat16x8(const _mlas_fp16_* Buffer) { return vreinterpretq_f16_u16(vld1q_u16(Buffer)); }

+MLAS_FORCEINLINE
+MLAS_FLOAT16X8
+MlasLoadf16Float16x8(const float16_t* Buffer) { return vld1q_f16(Buffer); }
+
 MLAS_FORCEINLINE
 MLAS_FLOAT16X4
 MlasLoadFloat16x4(const _mlas_fp16_* Buffer) { return vreinterpret_f16_u16(vld1_u16(Buffer)); }
@@ -115,6 +123,13 @@ MlasStoreFloat16x8(_mlas_fp16_* Buffer, MLAS_FLOAT16X8 Vector)
     vst1q_u16(Buffer, vreinterpretq_u16_f16(Vector));
 }

+MLAS_FORCEINLINE
+void
+MlasStoref16Float16x8(float16_t* Buffer, MLAS_FLOAT16X8 Vector)
+{
+    vst1q_f16(Buffer, Vector);
+}
+
 MLAS_FORCEINLINE
 void
 MlasStoreFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector)
@@ -579,4 +594,39 @@ MlasShiftLeftInt16(MLAS_INT16X4 Vector)
     return vshl_n_s16(Vector, ShiftCount);
 }

+MLAS_FORCEINLINE
+MLAS_FLOAT16X8
+MlasReciprocalSqrtFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
+{
+    return vrecpsq_f16(Vector1, Vector2);
+}
+
+MLAS_FORCEINLINE
+MLAS_FLOAT16X8
+MlasApproximateReciprocalFloat16(MLAS_FLOAT16X8 Vector)
+{
+    return vrecpeq_f16(Vector);
+}
+
+MLAS_FORCEINLINE
+MLAS_UINT16X8
+MlasCompareLessThanFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
+{
+    return vcltq_f16(Vector1, Vector2);
+}
+
+MLAS_FORCEINLINE
+MLAS_FLOAT16X8
+MlasAbsFloat16(MLAS_FLOAT16X8 Vector)
+{
+    return vabsq_f16(Vector);
+}
+
+MLAS_FORCEINLINE
+MLAS_FLOAT16X8
+MlasSelectFloat16(MLAS_UINT16X8 Vector, MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
+{
+    return vbslq_f16(Vector, Vector1, Vector2);
+}
+
 #endif // fp16 vector intrinsic supported
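
A note on the two reciprocal helpers added above (sketch, not part of the commit): MlasApproximateReciprocalFloat16 wraps vrecpeq_f16, a coarse estimate of 1/x, and MlasReciprocalSqrtFloat16 wraps vrecpsq_f16, the Newton-Raphson reciprocal refinement step (2 - a*b). The erf kernel multiplies the estimate by two refinement steps to get an FP16 division without a divide instruction. The same pattern in raw intrinsics, assuming an aarch64 compiler with -march=armv8.2-a+fp16:

#include <arm_neon.h>

float16x8_t ReciprocalFp16(float16x8_t d)
{
    float16x8_t r = vrecpeq_f16(d);        // rough estimate of 1/d
    r = vmulq_f16(r, vrecpsq_f16(d, r));   // first Newton-Raphson step
    r = vmulq_f16(r, vrecpsq_f16(d, r));   // second step, matching the kernel's two refinements
    return r;
}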
