Skip to content

Commit cf6d83f

Browse files
author
Sanket Kale
committed
Resolved Copilot comments
1 parent 937cd82 commit cf6d83f

File tree

9 files changed

+74
-57
lines changed

9 files changed

+74
-57
lines changed

cmake/onnxruntime_mlas.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -549,8 +549,8 @@ else()
549549
set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
550550
set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
551551
set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
552-
set_source_files_properties(${MLAS_SRC_DIR}/erf_neon_fp16.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+fp16 ")
553-
set_source_files_properties(${MLAS_SRC_DIR}/gelu_neon_fp16.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+fp16 ")
552+
set_source_files_properties(${MLAS_SRC_DIR}/erf_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
553+
set_source_files_properties(${MLAS_SRC_DIR}/gelu_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
554554
endif()
555555

556556
if(ONNXRUNTIME_MLAS_MULTI_ARCH)

onnxruntime/core/mlas/lib/erf_neon_fp16.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ exp_neg_rational_approx_f16(MLAS_FLOAT16X8 x)
6060
MLAS_FLOAT16X8 den = MlasMultiplyAddFloat16(d1v, x, d0v);
6161
den = MlasMultiplyAddFloat16(d2v, x2, den);
6262
MLAS_FLOAT16X8 recip = MlasApproximateReciprocalFloat16(den);
63-
recip = MlasMultiplyFloat16(recip, MlasReciprocalSqrtFloat16(den, recip));
64-
recip = MlasMultiplyFloat16(recip, MlasReciprocalSqrtFloat16(den, recip));
63+
recip = MlasMultiplyFloat16(recip, MlasReciprocalStepFloat16(den, recip));
64+
recip = MlasMultiplyFloat16(recip, MlasReciprocalStepFloat16(den, recip));
6565
MLAS_FLOAT16X8 result = MlasMultiplyFloat16(num, recip);
6666
return result;
6767
}
@@ -103,8 +103,8 @@ MlasNeonErfF16Kernel(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N)
103103
MLAS_FLOAT16X8 absx_clamped = MlasMinimumFloat16(absx, vth);
104104
MLAS_FLOAT16X8 denom = MlasMultiplyAddFloat16(vp, absx_clamped, vone);
105105
MLAS_FLOAT16X8 t = MlasApproximateReciprocalFloat16(denom);
106-
t = MlasMultiplyFloat16(t, MlasReciprocalSqrtFloat16(denom, t));
107-
t = MlasMultiplyFloat16(t, MlasReciprocalSqrtFloat16(denom, t));
106+
t = MlasMultiplyFloat16(t, MlasReciprocalStepFloat16(denom, t));
107+
t = MlasMultiplyFloat16(t, MlasReciprocalStepFloat16(denom, t));
108108
MLAS_FLOAT16X8 t2 = MlasMultiplyFloat16(t, t);
109109
MLAS_FLOAT16X8 t3 = MlasMultiplyFloat16(t2, t);
110110
MLAS_FLOAT16X8 t4 = MlasMultiplyFloat16(t3, t);

onnxruntime/core/mlas/lib/fp16_common.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,7 @@ MlasShiftLeftInt16(MLAS_INT16X4 Vector)
596596

597597
MLAS_FORCEINLINE
598598
MLAS_FLOAT16X8
599-
MlasReciprocalSqrtFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
599+
MlasReciprocalStepFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
600600
{
601601
return vrecpsq_f16(Vector1, Vector2);
602602
}

onnxruntime/core/mlas/lib/gelu.cpp

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,12 @@ Module Name:
88
99
Abstract:
1010
11-
This module contains Gelu helper functions .
11+
This module contains Gelu helper functions.
1212
1313
--*/
1414

1515
#include "gelu.h"
1616

17-
1817
void
1918
MLASCALL
2019
MlasComputeFP16Gelu(const MLAS_FP16* input,
@@ -23,26 +22,11 @@ MlasComputeFP16Gelu(const MLAS_FP16* input,
2322
int64_t count,
2423
const std::string& algo)
2524
{
26-
#if defined(MLAS_USE_SVE) || defined(MLAS_NEON_INTRINSICS)
27-
28-
bool done = false;
29-
3025
#if defined(MLAS_USE_SVE)
31-
if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) {
3226
MlasSveGeluF16Kernel(input, output, temp, count, algo);
33-
done = true;
34-
}
35-
#endif
36-
37-
#if defined(MLAS_NEON_INTRINSICS)
38-
if (!done) {
27+
#elif defined(MLAS_NEON_INTRINSICS)
3928
MlasNeonGeluF16Kernel(input, output, temp, count, algo);
40-
done = true;
41-
}
42-
#endif
43-
4429
#else
45-
4630
(void)temp;
4731
for (int64_t i = 0; i < count; ++i) {
4832
float x = static_cast<float>(input[i]);
@@ -63,6 +47,5 @@ MlasComputeFP16Gelu(const MLAS_FP16* input,
6347

6448
output[i] = MLAS_FP16(gelu_val);
6549
}
66-
6750
#endif
6851
}

onnxruntime/core/mlas/lib/gelu.h

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,17 @@ Copyright 2025 FUJITSU LIMITED
44
55
Module Name:
66
7-
Gelu.cpp
7+
gelu.h
88
99
Abstract:
1010
11-
This module contains Gelu helper functions .
11+
This module contains Gelu helper functions.
1212
1313
--*/
1414

1515
#include "fp16_common.h"
1616
#if defined(MLAS_NEON_INTRINSICS)
1717
#include "erf_neon_fp16.h"
18-
#endif
19-
20-
#ifdef MLAS_USE_SVE
21-
#include "sve/mlasi_sve.h"
22-
#endif
2318

2419
void
2520
MLASCALL
@@ -29,4 +24,10 @@ MlasNeonGeluF16Kernel(
2924
MLAS_FP16* temp,
3025
int64_t count,
3126
const std::string& algo
32-
);
27+
);
28+
29+
#endif
30+
31+
#ifdef MLAS_USE_SVE
32+
#include "sve/mlasi_sve.h"
33+
#endif

onnxruntime/core/mlas/lib/gelu_neon_fp16.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@ Copyright 2025 FUJITSU LIMITED
44
55
Module Name:
66
7-
Gelu.cpp
7+
gelu_neon_fp16.cpp
88
99
Abstract:
1010
1111
This module contains Gelu helper functions .
1212
1313
--*/
1414
#include "gelu.h"
15-
15+
#include <cmath>
1616
#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
1717

1818
void
@@ -48,15 +48,15 @@ MlasNeonGeluF16Kernel(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp
4848
// Tail
4949
for (; i < count; ++i) {
5050
float x = static_cast<float>(input[i]);
51-
float inner = x * (0.7979f + 0.03568f * x * x);
51+
float inner = x * (0.7978845608028654f + 0.035677408136300125f * x * x);
5252
inner = std::max(-5.0f, std::min(5.0f, inner));
5353
temp[i] = static_cast<MLAS_FP16>(inner);
5454
}
5555

5656
// Tanh processing
5757
MlasComputeTanh<MLAS_FP16>(temp, temp, count);
5858

59-
} else if (algo == "none") {
59+
} else {
6060
// Preprocess input into temp[] for erf
6161
for (i = 0; i + 7 < count; i += 8) {
6262
MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast<const float16_t*>(input + i));
@@ -67,7 +67,7 @@ MlasNeonGeluF16Kernel(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp
6767
// Tail
6868
for (; i < count; ++i) {
6969
float x = static_cast<float>(input[i]);
70-
temp[i] = static_cast<MLAS_FP16>(x * 0.70710678f);
70+
temp[i] = static_cast<MLAS_FP16>(x * static_cast<float>(M_SQRT1_2));
7171
}
7272

7373
// Erf processing

onnxruntime/core/mlas/lib/sve/elementwise_sve_fp16.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,8 @@ MlasSveGeluF16Kernel(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp,
190190
const __fp16 r1 = 0.5f;
191191
const __fp16 r2 = 1.0f;
192192
const __fp16 r3 = static_cast<float>(M_SQRT1_2);
193-
const __fp16 r4 = 0.7979f;
194-
const __fp16 r5 = 0.03568f;
193+
const __fp16 r4 = 0.7978845608028654f;
194+
const __fp16 r5 = 0.035677408136300125f;
195195

196196
const MLAS_SVFLOAT16 v_half = MlasSveBroadcastfloat16(r1);
197197
const MLAS_SVFLOAT16 v_one = MlasSveBroadcastfloat16(r2);
@@ -203,7 +203,7 @@ MlasSveGeluF16Kernel(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp,
203203
const __fp16 c2 = 5.0f;
204204
if (algo == "tanh") {
205205
int64_t i = 0;
206-
while (i < (count)) {
206+
while (i < count) {
207207
svbool_t pg = MlasSveSelPredictefloat16(i, count);
208208
MLAS_SVFLOAT16 v_x = MlasSveLoadFloat16(pg, &input[i]);
209209
MLAS_SVFLOAT16 v_x2 = MlasSveMulfloat16(pg, v_x, v_x);
@@ -225,7 +225,7 @@ MlasSveGeluF16Kernel(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp,
225225
MlasSveStoreF16(pg, &output[j], v_result);
226226
j += svcnth();
227227
}
228-
} else if (algo == "none") {
228+
} else {
229229
int64_t i = 0;
230230
while (i < (count)) {
231231
svbool_t pg = MlasSveSelPredictefloat16(i, count);

onnxruntime/core/mlas/lib/sve/mlasi_sve.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,13 @@ MlasSveTanhF16Kernel(
5656
void
5757
MLASCALL
5858
MlasSveGeluF16Kernel(
59-
const MLAS_FP16* input,
60-
MLAS_FP16* output,
61-
MLAS_FP16* temp,
62-
int64_t count,
63-
const std::string& algo
59+
const MLAS_FP16* Input,
60+
MLAS_FP16* Output,
61+
MLAS_FP16* Temp,
62+
int64_t N,
63+
const std::string& Algo
6464
);
65-
// function decarations
65+
// function declarations
6666
MLAS_FORCEINLINE
6767
MLAS_SVFLOAT32
6868
MlasSveComputeExpVector(

onnxruntime/core/providers/cpu/tensor/gelu.cc

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,31 @@
1212
#include "core/providers/cpu/element_wise_ranged_transform.h"
1313
#include "core/providers/cpu/tensor/gelu.h"
1414

15+
#include <cstddef>
16+
#include <cstdlib>
17+
#include <memory>
18+
19+
#if defined(_WIN32)
20+
#include <malloc.h>
21+
#endif
22+
23+
inline void* AlignedAlloc(size_t alignment, size_t size) {
24+
#if defined(_WIN32)
25+
return _aligned_malloc(size, alignment);
26+
#else
27+
// std::aligned_alloc requires size to be a multiple of alignment
28+
return std::aligned_alloc(alignment, size);
29+
#endif
30+
}
31+
32+
inline void AlignedFree(void* p) {
33+
#if defined(_WIN32)
34+
_aligned_free(p);
35+
#else
36+
std::free(p);
37+
#endif
38+
}
39+
1540
using onnxruntime::narrow;
1641
using namespace onnxruntime::common;
1742

@@ -128,16 +153,24 @@ Status Gelu<MLFloat16>::Compute(OpKernelContext* context) const {
128153

129154
// Alignment and buffer size for aligned_alloc
130155
constexpr size_t alignment = 64;
156+
131157
size_t buffer_size = elem_count * sizeof(MLFloat16);
132-
size_t aligned_size = ((buffer_size + alignment - 1) / alignment) * alignment;
133-
auto deleter = [](MLFloat16* p) { std::free(p); };
134-
std::unique_ptr<MLFloat16, decltype(deleter)> temp_fp16_aligned(
135-
reinterpret_cast<MLFloat16*>(std::aligned_alloc(alignment, aligned_size)),
136-
deleter);
137-
if (temp_fp16_aligned == nullptr) {
138-
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to allocate aligned temporary buffer.");
158+
size_t aligned_size =
159+
((buffer_size + alignment - 1) / alignment) * alignment;
160+
161+
void* raw = AlignedAlloc(alignment, aligned_size);
162+
if (!raw) {
163+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
164+
"Failed to allocate aligned temporary buffer.");
139165
}
140166

167+
auto deleter = [](MLFloat16* p) {
168+
AlignedFree(p);
169+
};
170+
171+
std::unique_ptr<MLFloat16, decltype(deleter)> temp_fp16_aligned(
172+
static_cast<MLFloat16*>(raw), deleter);
173+
141174
concurrency::ThreadPool::TryBatchParallelFor(
142175
tp,
143176
static_cast<int32_t>(task_count),
@@ -147,7 +180,7 @@ Status Gelu<MLFloat16>::Compute(OpKernelContext* context) const {
147180
MLFloat16* p_output = output_data + start;
148181
int64_t count = std::min(length_per_task, elem_count - start);
149182
MLFloat16* p_temp = temp_fp16_aligned.get() + start;
150-
MlasComputeFP16Gelu(p_input, p_output, p_temp, count, approximation_algorithm_);
183+
MlasComputeFP16Gelu(p_input, p_output, p_temp, count, approximation_algorithm_);
151184

152185
},
153186
0);

0 commit comments

Comments (0)