
Commit 3ccdca0

Resolve Review Comments
1 parent cc2625d commit 3ccdca0

14 files changed: +307 −185 lines


cmake/onnxruntime_mlas.cmake

Lines changed: 7 additions & 2 deletions

@@ -52,6 +52,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
   ${MLAS_SRC_DIR}/rotary_embedding.cpp
   ${MLAS_SRC_DIR}/softmax.h
   ${MLAS_SRC_DIR}/saturation_check.cpp
+  ${MLAS_SRC_DIR}/gelu.cpp
 )

 target_sources(onnxruntime_mlas PRIVATE
@@ -115,6 +116,7 @@ function(setup_mlas_source_for_windows)
   ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
   ${MLAS_SRC_DIR}/erf_neon_fp16.h
   ${MLAS_SRC_DIR}/erf_neon_fp16.cpp
+  ${MLAS_SRC_DIR}/gelu_neon_fp16.cpp
 )

 set(mlas_platform_preprocess_srcs
@@ -464,15 +466,16 @@ else()
   ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
   ${MLAS_SRC_DIR}/erf_neon_fp16.h
   ${MLAS_SRC_DIR}/erf_neon_fp16.cpp
+  ${MLAS_SRC_DIR}/gelu_neon_fp16.cpp
 )

 # Conditionally add the SVE implementation if compiler supports it
 if (onnxruntime_USE_SVE)
   list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/mlasi_sve.h)
   list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/elementwise_sve.cpp)
-  list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/Elementwise_sve_fp16.cpp)
+  list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/elementwise_sve_fp16.cpp)
   set_source_files_properties(${MLAS_SRC_DIR}/sve/elementwise_sve.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 ")
-  set_source_files_properties(${MLAS_SRC_DIR}/sve/Elementwise_sve_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 ")
+  set_source_files_properties(${MLAS_SRC_DIR}/sve/elementwise_sve_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 ")
   list(APPEND mlas_private_compile_definitions MLAS_USE_SVE)
 endif()

@@ -509,6 +512,7 @@ else()
   ${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp
   ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp
   ${MLAS_SRC_DIR}/erf_neon_fp16.cpp
+  ${MLAS_SRC_DIR}/gelu_neon_fp16.cpp
 )
 set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
 set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
@@ -525,6 +529,7 @@ else()
 set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
 set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
 set_source_files_properties(${MLAS_SRC_DIR}/erf_neon_fp16.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+fp16 ")
+set_source_files_properties(${MLAS_SRC_DIR}/gelu_neon_fp16.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+fp16 ")
 endif()

 if(ONNXRUNTIME_MLAS_MULTI_ARCH)

cmake/onnxruntime_providers_cpu.cmake

Lines changed: 1 addition & 8 deletions

@@ -172,14 +172,7 @@ if (onnxruntime_ENABLE_CPU_FP16_OPS)
   set_source_files_properties(${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/adasum_kernels.cc PROPERTIES COMPILE_FLAGS " -fassociative-math -ffast-math -ftree-vectorize -funsafe-math-optimizations -mf16c -mavx -mfma ")
 endif()

-if(onnxruntime_target_platform STREQUAL "aarch64" OR onnxruntime_target_platform STREQUAL "ARM64" OR onnxruntime_target_platform STREQUAL "arm64")
-  set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/tensor/gelu.cc" PROPERTIES COMPILE_FLAGS -march=armv8.2-a+fp16)
-endif()
-target_include_directories(onnxruntime_providers PRIVATE
-  ${ONNXRUNTIME_ROOT}
-  ${ONNXRUNTIME_ROOT}/core/mlas/inc
-)
-
+target_include_directories(onnxruntime_providers PRIVATE ${ONNXRUNTIME_ROOT})
 onnxruntime_add_include_to_target(onnxruntime_providers re2::re2 Eigen3::Eigen)
 add_dependencies(onnxruntime_providers onnx ${onnxruntime_EXTERNAL_DEPENDENCIES})

onnxruntime/core/mlas/inc/mlas.h

Lines changed: 18 additions & 0 deletions

@@ -2126,3 +2126,21 @@ void
 MLASCALL
 MlasGemmBatchPackUseKleidi(bool enable);
 #endif
+
+void
+MLASCALL
+MlasComputeFP16Erf(
+    const MLAS_FP16* Input,
+    MLAS_FP16* Output,
+    size_t N
+);
+
+void
+MLASCALL
+MlasComputeFP16Gelu(
+    const MLAS_FP16* input,
+    MLAS_FP16* output,
+    MLAS_FP16* temp,
+    int64_t count,
+    const std::string& algo
+);
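These two declarations are the new public FP16 elementwise entry points. A minimal calling sketch, not part of this commit: it assumes MLAS_FP16 is constructible from float (the new sources below use that conversion) and that the caller supplies a temp scratch buffer of count elements for the GELU path, since the kernels stage intermediate tanh/erf values there.

    // Hedged usage sketch (illustrative, not from this commit).
    #include <string>
    #include <vector>
    #include "mlas.h"

    void RunFp16Activations() {
        const size_t n = 1024;
        std::vector<MLAS_FP16> input(n, MLAS_FP16(0.5f));  // assumes float ctor
        std::vector<MLAS_FP16> output(n);
        std::vector<MLAS_FP16> temp(n);  // scratch: one slot per element

        // Elementwise erf over the whole buffer.
        MlasComputeFP16Erf(input.data(), output.data(), n);

        // GELU: "tanh" selects the tanh approximation, "none" the exact
        // erf form, per the kernels added in this commit.
        MlasComputeFP16Gelu(input.data(), output.data(), temp.data(),
                            static_cast<int64_t>(n), "tanh");
    }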

onnxruntime/core/mlas/lib/erf.cpp

Lines changed: 49 additions & 0 deletions

@@ -22,6 +22,15 @@ Module Name:
 --*/

 #include "mlasi.h"
+
+#ifdef MLAS_USE_SVE
+#include "sve/mlasi_sve.h"
+#endif
+
+#if defined(MLAS_NEON_INTRINSICS)
+#include "erf_neon_fp16.h"
+#endif
+
 //
 // Bundles the constants for use by kernels written in assembly.
 //
@@ -266,3 +275,43 @@ Return Value:
     MlasErfKernel(Input, Output, N);
 #endif
 }
+
+void
+MLASCALL
+MlasComputeFP16Erf(
+    const MLAS_FP16* Input,
+    MLAS_FP16* Output,
+    size_t N
+)
+{
+#if defined(MLAS_USE_SVE) || defined(MLAS_NEON_INTRINSICS)
+
+#if defined(MLAS_USE_SVE)
+    if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) {
+        MlasSveErfF16Kernel(
+            reinterpret_cast<const _mlas_fp16_*>(Input),
+            reinterpret_cast<_mlas_fp16_*>(Output),
+            N
+        );
+        return;
+    }
+#endif
+
+#if defined(MLAS_NEON_INTRINSICS)
+    MlasNeonErfF16Kernel(
+        reinterpret_cast<const _mlas_fp16_*>(Input),
+        reinterpret_cast<_mlas_fp16_*>(Output),
+        N
+    );
+    return;
+#endif
+
+#else
+    std::vector<float> input_fp32(N);
+    std::vector<float> output_fp32(N);

+    MlasConvertHalfToFloatBuffer(Input, input_fp32.data(), N);
+    MlasComputeErf(input_fp32.data(), output_fp32.data(), N);
+    MlasConvertFloatToHalfBuffer(output_fp32.data(), Output, N);
+#endif
+}
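The non-Arm fallback above widens to FP32, reuses the existing MlasComputeErf, and narrows back. A hedged test sketch of that equivalence against std::erf, reusing the same conversion helper the fallback calls; the function name and tolerance here are illustrative assumptions, not part of the commit:

    // Checks MlasComputeFP16Erf against a scalar std::erf reference.
    #include <cmath>
    #include <vector>
    #include "mlas.h"

    bool CheckFp16Erf(const std::vector<MLAS_FP16>& input) {
        const size_t n = input.size();
        std::vector<MLAS_FP16> actual(n);
        MlasComputeFP16Erf(input.data(), actual.data(), n);

        // Widen both buffers with the helper the fallback path itself uses.
        std::vector<float> x(n), y(n);
        MlasConvertHalfToFloatBuffer(input.data(), x.data(), n);
        MlasConvertHalfToFloatBuffer(actual.data(), y.data(), n);

        for (size_t i = 0; i < n; ++i) {
            // Loose tolerance: FP16 carries roughly 3 decimal digits.
            if (std::abs(y[i] - std::erf(x[i])) > 1e-2f) return false;
        }
        return true;
    }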

onnxruntime/core/mlas/lib/erf_neon_fp16.cpp

Lines changed: 2 additions & 2 deletions

@@ -67,7 +67,7 @@ exp_neg_rational_approx_f16(MLAS_FLOAT16X8 x)
 }

 void
-MlasNeonErfKernelFp16(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N)
+MlasNeonErfF16Kernel(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N)
 {
     const float16_t p = 0.328f;
     const float16_t a1 = 0.2505f;
@@ -144,4 +144,4 @@ MlasNeonErfKernelFp16(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N)

     Output[i] = float_to_fp16(erf_approx);
 }
-}
\ No newline at end of file
+}

onnxruntime/core/mlas/lib/erf_neon_fp16.h

Lines changed: 1 addition & 1 deletion

@@ -21,4 +21,4 @@ Module Name:
 #include "softmax_kernel_neon.h"

 using _mlas_fp16_ = uint16_t;
-void MlasNeonErfKernelFp16(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N);
+void MlasNeonErfF16Kernel(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N);

onnxruntime/core/mlas/lib/gelu.cpp

Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
+/*++
+
+Copyright 2025 FUJITSU LIMITED
+
+Module Name:
+
+    gelu.cpp
+
+Abstract:
+
+    This module contains Gelu helper functions.
+
+--*/
+
+#include "gelu.h"
+
+void
+MLASCALL
+MlasComputeFP16Gelu(const MLAS_FP16* input,
+                    MLAS_FP16* output,
+                    MLAS_FP16* temp,
+                    int64_t count,
+                    const std::string& algo)
+{
+#if defined(MLAS_USE_SVE) || defined(MLAS_NEON_INTRINSICS)
+
+    bool done = false;
+
+#if defined(MLAS_USE_SVE)
+    if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) {
+        MlasSveGeluF16Kernel(input, output, temp, count, algo);
+        done = true;
+    }
+#endif
+
+#if defined(MLAS_NEON_INTRINSICS)
+    if (!done) {
+        MlasNeonGeluF16Kernel(input, output, temp, count, algo);
+        done = true;
+    }
+#endif
+
+#else
+
+    (void)temp;
+    for (int64_t i = 0; i < count; ++i) {
+        float x = static_cast<float>(input[i]);
+        float gelu_val;
+
+        if (algo == "tanh") {
+            // GELU approximation (tanh)
+            const float B = 0.7978845608f;
+            const float C = 0.044715f * B;
+            float tanh_arg = x * (B + C * x * x);
+            float tanh_res = std::tanh(tanh_arg);
+            gelu_val = 0.5f * x * (1.0f + tanh_res);
+        } else {
+            // GELU exact (erf)
+            gelu_val = 0.5f * x *
+                       (1.0f + std::erf(x * static_cast<float>(M_SQRT1_2)));
+        }
+
+        output[i] = MLAS_FP16(gelu_val);
+    }

+#endif
+}
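Written out, the two fallback branches are the standard GELU definitions:

    % Exact GELU (the erf branch):
    \mathrm{GELU}(x) = \tfrac{1}{2}\,x\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right)

    % Tanh approximation (the "tanh" branch). With B = \sqrt{2/\pi} \approx 0.7978845608
    % and C = 0.044715\,B, the argument x\,(B + C\,x^{2}) equals \sqrt{2/\pi}\,(x + 0.044715\,x^{3}):
    \mathrm{GELU}(x) \approx \tfrac{1}{2}\,x\left(1 + \tanh\!\left(\sqrt{\tfrac{2}{\pi}}\left(x + 0.044715\,x^{3}\right)\right)\right)

The pre-multiplied constant C appears in the NEON kernel below as v_C1 = 0.035677408136300125, so the scalar and vector paths evaluate the same polynomial.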

onnxruntime/core/mlas/lib/gelu.h

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+/*++
+
+Copyright 2025 FUJITSU LIMITED
+
+Module Name:
+
+    gelu.h
+
+Abstract:
+
+    This module contains Gelu helper functions.
+
+--*/
+
+#include "fp16_common.h"
+#if defined(MLAS_NEON_INTRINSICS)
+#include "erf_neon_fp16.h"
+#endif
+
+#ifdef MLAS_USE_SVE
+#include "sve/mlasi_sve.h"
+#endif
+
+void
+MLASCALL
+MlasNeonGeluF16Kernel(
+    const MLAS_FP16* input,
+    MLAS_FP16* output,
+    MLAS_FP16* temp,
+    int64_t count,
+    const std::string& algo
+);
onnxruntime/core/mlas/lib/gelu_neon_fp16.cpp

Lines changed: 93 additions & 0 deletions

@@ -0,0 +1,93 @@
+/*++
+
+Copyright 2025 FUJITSU LIMITED
+
+Module Name:
+
+    gelu_neon_fp16.cpp
+
+Abstract:
+
+    This module contains Gelu helper functions.
+
+--*/
+#include "gelu.h"
+
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+
+void
+MLASCALL
+MlasNeonGeluF16Kernel(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp, int64_t count, const std::string& algo)
+{
+    const float16_t v_half1 = 0.5f;
+    const float16_t v_one1 = 1.0f;
+    const float16_t v_sqrt1_21 = static_cast<float>(M_SQRT1_2);
+    const float16_t v_B1 = 0.7978845608028654f;
+    const float16_t v_C1 = 0.035677408136300125f;
+    const float16_t c1 = 5.0f;
+    const float16_t c2 = -5.0f;
+    const MLAS_FLOAT16X8 v_half = MlasBroadcastF16Float16x8(v_half1);
+    const MLAS_FLOAT16X8 v_one = MlasBroadcastF16Float16x8(v_one1);
+    const MLAS_FLOAT16X8 v_sqrt1_2 = MlasBroadcastF16Float16x8(v_sqrt1_21);
+    const MLAS_FLOAT16X8 v_B = MlasBroadcastF16Float16x8(v_B1);
+    const MLAS_FLOAT16X8 v_C = MlasBroadcastF16Float16x8(v_C1);
+
+    int64_t i = 0;
+
+    if (algo == "tanh") {
+        // Preprocess input into temp[] for tanh
+        for (; i + 7 < count; i += 8) {
+            MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast<const float16_t*>(input + i));
+            MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16(x, x);
+            MLAS_FLOAT16X8 inner = MlasMultiplyAddFloat16(v_C, x2, v_B);  // B + C * x^2
+            MLAS_FLOAT16X8 tanh_arg = MlasMultiplyFloat16(x, inner);      // x * (B + C * x^2)
+            tanh_arg = MlasMaximumFloat16(MlasBroadcastF16Float16x8(c2), MlasMinimumFloat16(tanh_arg, MlasBroadcastF16Float16x8(c1)));
+            MlasStoref16Float16x8(reinterpret_cast<float16_t*>(temp + i), tanh_arg);
+        }
+
+        // Tail
+        for (; i < count; ++i) {
+            float x = static_cast<float>(input[i]);
+            float inner = x * (0.7979f + 0.03568f * x * x);
+            inner = std::max(-5.0f, std::min(5.0f, inner));
+            temp[i] = static_cast<MLAS_FP16>(inner);
+        }
+
+        // Tanh processing
+        MlasComputeTanh<MLAS_FP16>(temp, temp, count);
+
+    } else if (algo == "none") {
+        // Preprocess input into temp[] for erf
+        for (i = 0; i + 7 < count; i += 8) {
+            MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast<const float16_t*>(input + i));
+            MLAS_FLOAT16X8 scaled = MlasMultiplyFloat16(x, v_sqrt1_2);
+            MlasStoref16Float16x8(reinterpret_cast<float16_t*>(temp + i), scaled);
+        }
+
+        // Tail
+        for (; i < count; ++i) {
+            float x = static_cast<float>(input[i]);
+            temp[i] = static_cast<MLAS_FP16>(x * 0.70710678f);
+        }
+
+        // Erf processing
+        MlasNeonErfF16Kernel(reinterpret_cast<const _mlas_fp16_*>(temp), reinterpret_cast<_mlas_fp16_*>(temp), count);
+    }
+
+    // Final GELU output = 0.5 * x * (1 + tanh|erf)
+    i = 0;
+    for (; i + 7 < count; i += 8) {
+        MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast<const float16_t*>(input + i));
+        MLAS_FLOAT16X8 t = MlasLoadf16Float16x8(reinterpret_cast<const float16_t*>(temp + i));
+        MLAS_FLOAT16X8 result = MlasMultiplyFloat16(v_half, MlasMultiplyFloat16(x, MlasAddFloat16(v_one, t)));
+        MlasStoref16Float16x8(reinterpret_cast<float16_t*>(output + i), result);
+    }
+
+    for (; i < count; ++i) {
+        float x = static_cast<float>(input[i]);
+        float t = static_cast<float>(temp[i]);
+        float gelu = 0.5f * x * (1.0f + t);
+        output[i] = static_cast<MLAS_FP16>(gelu);
+    }
+}
+#endif
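The kernel makes two passes through temp: a preprocessing pass that stores the clamped tanh argument (or x/√2 for the erf path), then a shared finalize pass computing 0.5·x·(1 + t). The ±5 clamp is harmless because tanh(±5) ≈ ±0.9999 already rounds to ±1 in FP16, and it also bounds any overflowed intermediate before the tanh evaluation. A hedged scalar restatement of the same two-pass structure, using float instead of float16_t for readability (names illustrative, not from the commit):

    #include <algorithm>
    #include <cmath>
    #include <string>

    // Scalar mirror of MlasNeonGeluF16Kernel's two-pass structure.
    void GeluTwoPassReference(const float* input, float* output, float* temp,
                              int64_t count, const std::string& algo) {
        const float B = 0.7978845608028654f;    // sqrt(2/pi)
        const float C = 0.035677408136300125f;  // 0.044715 * sqrt(2/pi)

        // Pass 1: stage tanh(arg) or erf(x / sqrt(2)) in temp.
        for (int64_t i = 0; i < count; ++i) {
            float x = input[i];
            if (algo == "tanh") {
                float arg = x * (B + C * x * x);
                arg = std::max(-5.0f, std::min(5.0f, arg));  // tanh saturates
                temp[i] = std::tanh(arg);
            } else {
                temp[i] = std::erf(x * 0.70710678f);  // x / sqrt(2)
            }
        }

        // Pass 2 (shared): GELU(x) = 0.5 * x * (1 + t).
        for (int64_t i = 0; i < count; ++i) {
            output[i] = 0.5f * input[i] * (1.0f + temp[i]);
        }
    }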
