
Commit 937cd82

akote123 authored and Sanket Kale committed
Resolve Review Comments
1 parent 98d5401 commit 937cd82

14 files changed: +318 -185 lines changed


cmake/onnxruntime_mlas.cmake

Lines changed: 7 additions & 2 deletions
@@ -54,6 +54,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
   ${MLAS_SRC_DIR}/rotary_embedding.cpp
   ${MLAS_SRC_DIR}/softmax.h
   ${MLAS_SRC_DIR}/saturation_check.cpp
+  ${MLAS_SRC_DIR}/gelu.cpp
 )
 
 target_sources(onnxruntime_mlas PRIVATE
@@ -118,6 +119,7 @@ function(setup_mlas_source_for_windows)
   ${MLAS_SRC_DIR}/sconv_nchw_kernel_neon.cpp
   ${MLAS_SRC_DIR}/erf_neon_fp16.h
   ${MLAS_SRC_DIR}/erf_neon_fp16.cpp
+  ${MLAS_SRC_DIR}/gelu_neon_fp16.cpp
 )
 
 set(mlas_platform_preprocess_srcs
@@ -483,15 +485,16 @@ else()
   ${MLAS_SRC_DIR}/sconv_nchw_kernel_neon.cpp
   ${MLAS_SRC_DIR}/erf_neon_fp16.h
   ${MLAS_SRC_DIR}/erf_neon_fp16.cpp
+  ${MLAS_SRC_DIR}/gelu_neon_fp16.cpp
 )
 
 # Conditionally add the SVE implementation if compiler supports it
 if (onnxruntime_USE_SVE)
   list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/mlasi_sve.h)
   list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/elementwise_sve.cpp)
-  list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/Elementwise_sve_fp16.cpp)
+  list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/elementwise_sve_fp16.cpp)
   set_source_files_properties(${MLAS_SRC_DIR}/sve/elementwise_sve.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 ")
-  set_source_files_properties(${MLAS_SRC_DIR}/sve/Elementwise_sve_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 ")
+  set_source_files_properties(${MLAS_SRC_DIR}/sve/elementwise_sve_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 ")
   list(APPEND mlas_private_compile_definitions MLAS_USE_SVE)
 endif()
 
@@ -529,6 +532,7 @@ else()
   ${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp
   ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp
   ${MLAS_SRC_DIR}/erf_neon_fp16.cpp
+  ${MLAS_SRC_DIR}/gelu_neon_fp16.cpp
 )
 set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
 set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
@@ -546,6 +550,7 @@ else()
   set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
   set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
   set_source_files_properties(${MLAS_SRC_DIR}/erf_neon_fp16.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+fp16 ")
+  set_source_files_properties(${MLAS_SRC_DIR}/gelu_neon_fp16.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+fp16 ")
 endif()
 
 if(ONNXRUNTIME_MLAS_MULTI_ARCH)

cmake/onnxruntime_providers_cpu.cmake

Lines changed: 1 addition & 8 deletions
@@ -182,14 +182,7 @@ if (onnxruntime_ENABLE_CPU_FP16_OPS)
   set_source_files_properties(${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/adasum_kernels.cc PROPERTIES COMPILE_FLAGS " -fassociative-math -ffast-math -ftree-vectorize -funsafe-math-optimizations -mf16c -mavx -mfma ")
 endif()
 
-if(onnxruntime_target_platform STREQUAL "aarch64" OR onnxruntime_target_platform STREQUAL "ARM64" OR onnxruntime_target_platform STREQUAL "arm64")
-  set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/tensor/gelu.cc" PROPERTIES COMPILE_FLAGS -march=armv8.2-a+fp16)
-endif()
-target_include_directories(onnxruntime_providers PRIVATE
-  ${ONNXRUNTIME_ROOT}
-  ${ONNXRUNTIME_ROOT}/core/mlas/inc
-)
-
+target_include_directories(onnxruntime_providers PRIVATE ${ONNXRUNTIME_ROOT})
 onnxruntime_add_include_to_target(onnxruntime_providers re2::re2 Eigen3::Eigen)
 add_dependencies(onnxruntime_providers onnx ${onnxruntime_EXTERNAL_DEPENDENCIES})
 

onnxruntime/core/mlas/inc/mlas.h

Lines changed: 29 additions & 0 deletions
@@ -2127,3 +2127,32 @@ MlasFlashAttention(
     MlasFlashAttentionThreadedArgs* args,
     MLAS_THREADPOOL* ThreadPool
 );
+
+#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
+/**
+ * @brief Function to override the packing mechanism decision if kleidi ai is included
+ * @param enable enable kleidiai packing (allow or disallow depending on true/false)
+ * @return
+ */
+void
+MLASCALL
+MlasGemmBatchPackUseKleidi(bool enable);
+#endif
+
+void
+MLASCALL
+MlasComputeFP16Erf(
+    const MLAS_FP16* Input,
+    MLAS_FP16* Output,
+    size_t N
+);
+
+void
+MLASCALL
+MlasComputeFP16Gelu(
+    const MLAS_FP16* input,
+    MLAS_FP16* output,
+    MLAS_FP16* temp,
+    int64_t count,
+    const std::string& algo
+);
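For context, here is a minimal caller sketch of the two new FP16 entry points declared above. The helper name, the sample data, and the link against the MLAS static library are assumptions made for illustration only; the extra temp scratch buffer and the "tanh"/"none" algorithm strings follow the GELU kernels added later in this commit.

    // Hypothetical caller, not part of the commit. Assumes the translation unit
    // has onnxruntime/core/mlas/inc on the include path and links onnxruntime_mlas.
    #include <string>
    #include <vector>
    #include "mlas.h"

    void RunFp16Activations(const std::vector<float>& src) {
        const size_t n = src.size();
        std::vector<MLAS_FP16> input(n), erf_out(n), gelu_out(n), temp(n);
        for (size_t i = 0; i < n; ++i) {
            input[i] = MLAS_FP16(src[i]);  // narrow each value to half precision
        }

        // Element-wise erf over the half-precision buffer.
        MlasComputeFP16Erf(input.data(), erf_out.data(), n);

        // GELU with the tanh approximation; pass "none" for the exact erf form.
        // temp is caller-provided scratch of the same length as the input.
        MlasComputeFP16Gelu(input.data(), gelu_out.data(), temp.data(),
                            static_cast<int64_t>(n), "tanh");
    }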

onnxruntime/core/mlas/lib/erf.cpp

Lines changed: 49 additions & 0 deletions
@@ -22,6 +22,15 @@ Module Name:
 --*/
 
 #include "mlasi.h"
+
+#ifdef MLAS_USE_SVE
+#include "sve/mlasi_sve.h"
+#endif
+
+#if defined(MLAS_NEON_INTRINSICS)
+#include "erf_neon_fp16.h"
+#endif
+
 //
 // Bundles the constants for use by kernels written in assembly.
 //
@@ -266,3 +275,43 @@ Return Value:
     MlasErfKernel(Input, Output, N);
 #endif
 }
+
+void
+MLASCALL
+MlasComputeFP16Erf(
+    const MLAS_FP16* Input,
+    MLAS_FP16* Output,
+    size_t N
+    )
+{
+#if defined(MLAS_USE_SVE) || defined(MLAS_NEON_INTRINSICS)
+
+#if defined(MLAS_USE_SVE)
+    if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) {
+        MlasSveErfF16Kernel(
+            reinterpret_cast<const _mlas_fp16_*>(Input),
+            reinterpret_cast<_mlas_fp16_*>(Output),
+            N
+        );
+        return;
+    }
+#endif
+
+#if defined(MLAS_NEON_INTRINSICS)
+    MlasNeonErfF16Kernel(
+        reinterpret_cast<const _mlas_fp16_*>(Input),
+        reinterpret_cast<_mlas_fp16_*>(Output),
+        N
+    );
+    return;
+#endif
+
+#else
+    std::vector<float> input_fp32(N);
+    std::vector<float> output_fp32(N);
+
+    MlasConvertHalfToFloatBuffer(Input, input_fp32.data(), N);
+    MlasComputeErf(input_fp32.data(), output_fp32.data(), N);
+    MlasConvertFloatToHalfBuffer(output_fp32.data(), Output, N);
+#endif
+}
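The new MlasComputeFP16Erf dispatches to the SVE kernel when HasArmSve() reports support, otherwise to the NEON kernel, and on other builds falls back to a float32 round trip through MlasComputeErf. A rough check harness under those assumptions (hypothetical, not part of the commit) just reports how far the FP16 path drifts from float std::erf:

    // Hypothetical harness: max absolute difference between the FP16 erf path
    // and float std::erf over [-4, 4]. Assumes linkage against onnxruntime_mlas.
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>
    #include "mlas.h"

    int main() {
        const size_t n = 1024;
        std::vector<MLAS_FP16> in(n), out(n);
        for (size_t i = 0; i < n; ++i) {
            float x = -4.0f + 8.0f * static_cast<float>(i) / static_cast<float>(n - 1);
            in[i] = MLAS_FP16(x);
        }

        MlasComputeFP16Erf(in.data(), out.data(), n);

        float max_err = 0.0f;
        for (size_t i = 0; i < n; ++i) {
            float x = static_cast<float>(in[i]);   // value after fp16 rounding
            float err = std::fabs(static_cast<float>(out[i]) - std::erf(x));
            max_err = std::max(max_err, err);
        }
        std::printf("max |fp16 erf - erf| = %g\n", max_err);
        return 0;
    }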

onnxruntime/core/mlas/lib/erf_neon_fp16.cpp

Lines changed: 2 additions & 2 deletions
@@ -67,7 +67,7 @@ exp_neg_rational_approx_f16(MLAS_FLOAT16X8 x)
 }
 
 void
-MlasNeonErfKernelFp16(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N)
+MlasNeonErfF16Kernel(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N)
 {
     const float16_t p = 0.328f;
     const float16_t a1 = 0.2505f;
@@ -144,4 +144,4 @@ MlasNeonErfKernelFp16(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N)
 
         Output[i] = float_to_fp16(erf_approx);
     }
-}
+}

onnxruntime/core/mlas/lib/erf_neon_fp16.h

Lines changed: 1 addition & 1 deletion
@@ -21,4 +21,4 @@ Module Name:
 #include "softmax_kernel_neon.h"
 
 using _mlas_fp16_ = uint16_t;
-void MlasNeonErfKernelFp16(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N);
+void MlasNeonErfF16Kernel(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N);

onnxruntime/core/mlas/lib/gelu.cpp

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+/*++
+
+Copyright 2025 FUJITSU LIMITED
+
+Module Name:
+
+    Gelu.cpp
+
+Abstract:
+
+    This module contains Gelu helper functions .
+
+--*/
+
+#include "gelu.h"
+
+
+void
+MLASCALL
+MlasComputeFP16Gelu(const MLAS_FP16* input,
+                    MLAS_FP16* output,
+                    MLAS_FP16* temp,
+                    int64_t count,
+                    const std::string& algo)
+{
+#if defined(MLAS_USE_SVE) || defined(MLAS_NEON_INTRINSICS)
+
+    bool done = false;
+
+#if defined(MLAS_USE_SVE)
+    if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) {
+        MlasSveGeluF16Kernel(input, output, temp, count, algo);
+        done = true;
+    }
+#endif
+
+#if defined(MLAS_NEON_INTRINSICS)
+    if (!done) {
+        MlasNeonGeluF16Kernel(input, output, temp, count, algo);
+        done = true;
+    }
+#endif
+
+#else
+
+    (void)temp;
+    for (int64_t i = 0; i < count; ++i) {
+        float x = static_cast<float>(input[i]);
+        float gelu_val;
+
+        if (algo == "tanh") {
+            // GELU approximation (tanh)
+            const float B = 0.7978845608f;
+            const float C = 0.044715f * B;
+            float tanh_arg = x * (B + C * x * x);
+            float tanh_res = std::tanh(tanh_arg);
+            gelu_val = 0.5f * x * (1.0f + tanh_res);
+        } else {
+            // GELU exact (erf)
+            gelu_val = 0.5f * x *
+                       (1.0f + std::erf(x * static_cast<float>(M_SQRT1_2)));
+        }
+
+        output[i] = MLAS_FP16(gelu_val);
+    }
+
+#endif
+}
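As a reference for the constants above: the else branch computes the exact GELU, and the "tanh" branch its usual approximation. B is sqrt(2/pi), and the code folds the 0.044715 cubic coefficient into C = 0.044715 * B, the same pre-multiplied value (about 0.0356774) that appears as v_C1 in the NEON kernel below.

    \mathrm{GELU}(x) = \tfrac{1}{2}\,x\left(1 + \operatorname{erf}\!\left(\tfrac{x}{\sqrt{2}}\right)\right)

    \mathrm{GELU}_{\tanh}(x) \approx \tfrac{1}{2}\,x\left(1 + \tanh\!\left(\sqrt{\tfrac{2}{\pi}}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right)
        = \tfrac{1}{2}\,x\left(1 + \tanh\!\bigl(x\,(B + C\,x^{2})\bigr)\right),
    \qquad B = \sqrt{2/\pi} \approx 0.79788456,\quad C = 0.044715\,B \approx 0.03567741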

onnxruntime/core/mlas/lib/gelu.h

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+/*++
+
+Copyright 2025 FUJITSU LIMITED
+
+Module Name:
+
+    Gelu.cpp
+
+Abstract:
+
+    This module contains Gelu helper functions .
+
+--*/
+
+#include "fp16_common.h"
+#if defined(MLAS_NEON_INTRINSICS)
+#include "erf_neon_fp16.h"
+#endif
+
+#ifdef MLAS_USE_SVE
+#include "sve/mlasi_sve.h"
+#endif
+
+void
+MLASCALL
+MlasNeonGeluF16Kernel(
+    const MLAS_FP16* input,
+    MLAS_FP16* output,
+    MLAS_FP16* temp,
+    int64_t count,
+    const std::string& algo
+);
onnxruntime/core/mlas/lib/gelu_neon_fp16.cpp

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+/*++
+
+Copyright 2025 FUJITSU LIMITED
+
+Module Name:
+
+    Gelu.cpp
+
+Abstract:
+
+    This module contains Gelu helper functions .
+
+--*/
+#include "gelu.h"
+
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+
+void
+MLASCALL
+MlasNeonGeluF16Kernel(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp, int64_t count, const std::string& algo)
+{
+    const float16_t v_half1 = 0.5f;
+    const float16_t v_one1 = 1.0f;
+    const float16_t v_sqrt1_21 = static_cast<float>(M_SQRT1_2);
+    const float16_t v_B1 = 0.7978845608028654f;
+    const float16_t v_C1 = 0.035677408136300125f;
+    const float16_t c1 = 5.0f;
+    const float16_t c2 = -5.0f;
+    const MLAS_FLOAT16X8 v_half = MlasBroadcastF16Float16x8(v_half1);
+    const MLAS_FLOAT16X8 v_one = MlasBroadcastF16Float16x8(v_one1);
+    const MLAS_FLOAT16X8 v_sqrt1_2 = MlasBroadcastF16Float16x8(v_sqrt1_21);
+    const MLAS_FLOAT16X8 v_B = MlasBroadcastF16Float16x8(v_B1);
+    const MLAS_FLOAT16X8 v_C = MlasBroadcastF16Float16x8(v_C1);
+
+    int64_t i = 0;
+
+    if (algo == "tanh") {
+        // Preprocess input into temp[] for tanh
+        for (; i + 7 < count; i += 8) {
+            MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast<const float16_t*>(input + i));
+            MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16(x, x);
+            MLAS_FLOAT16X8 inner = MlasMultiplyAddFloat16(v_C, x2, v_B); // B + C * x^2
+            MLAS_FLOAT16X8 tanh_arg = MlasMultiplyFloat16(x, inner); // x * (B + C * x^2)
+            tanh_arg = MlasMaximumFloat16(MlasBroadcastF16Float16x8(c2), MlasMinimumFloat16(tanh_arg, MlasBroadcastF16Float16x8(c1)));
+            MlasStoref16Float16x8(reinterpret_cast<float16_t*>(temp + i), tanh_arg);
+        }
+
+        // Tail
+        for (; i < count; ++i) {
+            float x = static_cast<float>(input[i]);
+            float inner = x * (0.7979f + 0.03568f * x * x);
+            inner = std::max(-5.0f, std::min(5.0f, inner));
+            temp[i] = static_cast<MLAS_FP16>(inner);
+        }
+
+        // Tanh processing
+        MlasComputeTanh<MLAS_FP16>(temp, temp, count);
+
+    } else if (algo == "none") {
+        // Preprocess input into temp[] for erf
+        for (i = 0; i + 7 < count; i += 8) {
+            MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast<const float16_t*>(input + i));
+            MLAS_FLOAT16X8 scaled = MlasMultiplyFloat16(x, v_sqrt1_2);
+            MlasStoref16Float16x8(reinterpret_cast<float16_t*>(temp + i), scaled);
+        }
+
+        // Tail
+        for (; i < count; ++i) {
+            float x = static_cast<float>(input[i]);
+            temp[i] = static_cast<MLAS_FP16>(x * 0.70710678f);
+        }
+
+        // Erf processing
+        MlasNeonErfF16Kernel(reinterpret_cast<const _mlas_fp16_*>(temp), reinterpret_cast<_mlas_fp16_*>(temp), count);
+    }
+
+    // Final GELU output = 0.5 * x * (1 + tanh|erf)
+    i = 0;
+    for (; i + 7 < count; i += 8) {
+        MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast<const float16_t*>(input + i));
+        MLAS_FLOAT16X8 t = MlasLoadf16Float16x8(reinterpret_cast<const float16_t*>(temp + i));
+        MLAS_FLOAT16X8 result = MlasMultiplyFloat16(v_half, MlasMultiplyFloat16(x, MlasAddFloat16(v_one, t)));
+        MlasStoref16Float16x8(reinterpret_cast<float16_t*>(output + i), result);
+    }
+
+    for (; i < count; ++i) {
+        float x = static_cast<float>(input[i]);
+        float t = static_cast<float>(temp[i]);
+        float gelu = 0.5f * x * (1.0f + t);
+        output[i] = static_cast<MLAS_FP16>(gelu);
+    }
+}
+#endif
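To make the structure of the kernel above easier to follow, here is a scalar float32 sketch of the same two-pass scheme (hypothetical, not part of the commit): pass one writes the tanh/erf argument into the temp buffer, the transcendental is then applied in place over temp (the real kernel calls the vectorized MlasComputeTanh or MlasNeonErfF16Kernel), and pass two forms 0.5 * x * (1 + temp[i]).

    // Scalar reference of the two-pass GELU used by the NEON kernel.
    // GeluTwoPassReference is a hypothetical name; std::tanh / std::erf stand in
    // for the vectorized transcendental kernels.
    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <string>

    void GeluTwoPassReference(const float* input, float* output, float* temp,
                              int64_t count, const std::string& algo) {
        const float B = 0.7978845608f;   // sqrt(2 / pi)
        const float C = 0.044715f * B;   // pre-multiplied cubic coefficient

        // Pass 1: write the argument of the transcendental into temp[].
        for (int64_t i = 0; i < count; ++i) {
            float x = input[i];
            temp[i] = (algo == "tanh")
                          ? std::clamp(x * (B + C * x * x), -5.0f, 5.0f)
                          : x * 0.70710678f;   // x / sqrt(2)
        }

        // Transcendental applied in place over temp[].
        for (int64_t i = 0; i < count; ++i) {
            temp[i] = (algo == "tanh") ? std::tanh(temp[i]) : std::erf(temp[i]);
        }

        // Pass 2: GELU(x) = 0.5 * x * (1 + t).
        for (int64_t i = 0; i < count; ++i) {
            output[i] = 0.5f * input[i] * (1.0f + temp[i]);
        }
    }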
