1212#include " core/providers/cpu/element_wise_ranged_transform.h"
1313#include " core/providers/cpu/tensor/gelu.h"
1414
15+ #include < cstddef>
16+ #include < cstdlib>
17+ #include < memory>
18+
19+ #if defined(_WIN32)
20+ #include < malloc.h>
21+ #endif
22+
23+ inline void * AlignedAlloc (size_t alignment, size_t size) {
24+ #if defined(_WIN32)
25+ return _aligned_malloc (size, alignment);
26+ #else
27+ // std::aligned_alloc requires size to be a multiple of alignment
28+ return std::aligned_alloc (alignment, size);
29+ #endif
30+ }
31+
32+ inline void AlignedFree (void * p) {
33+ #if defined(_WIN32)
34+ _aligned_free (p);
35+ #else
36+ std::free (p);
37+ #endif
38+ }
39+
1540using onnxruntime::narrow;
1641using namespace onnxruntime ::common;
1742
@@ -128,16 +153,24 @@ Status Gelu<MLFloat16>::Compute(OpKernelContext* context) const {
128153
129154 // Alignment and buffer size for aligned_alloc
130155 constexpr size_t alignment = 64 ;
156+
131157 size_t buffer_size = elem_count * sizeof (MLFloat16);
132- size_t aligned_size = ((buffer_size + alignment - 1 ) / alignment) * alignment;
133- auto deleter = [](MLFloat16* p) { std::free (p); } ;
134- std::unique_ptr<MLFloat16, decltype (deleter)> temp_fp16_aligned (
135- reinterpret_cast <MLFloat16*>( std::aligned_alloc ( alignment, aligned_size)),
136- deleter);
137- if (temp_fp16_aligned == nullptr ) {
138- return ORT_MAKE_STATUS (ONNXRUNTIME, FAIL, " Failed to allocate aligned temporary buffer." );
158+ size_t aligned_size =
159+ ((buffer_size + alignment - 1 ) / alignment) * alignment ;
160+
161+ void * raw = AlignedAlloc ( alignment, aligned_size);
162+ if (!raw) {
163+ return ORT_MAKE_STATUS (ONNXRUNTIME, FAIL,
164+ " Failed to allocate aligned temporary buffer." );
139165 }
140166
167+ auto deleter = [](MLFloat16* p) {
168+ AlignedFree (p);
169+ };
170+
171+ std::unique_ptr<MLFloat16, decltype (deleter)> temp_fp16_aligned (
172+ static_cast <MLFloat16*>(raw), deleter);
173+
141174 concurrency::ThreadPool::TryBatchParallelFor (
142175 tp,
143176 static_cast <int32_t >(task_count),
@@ -147,7 +180,7 @@ Status Gelu<MLFloat16>::Compute(OpKernelContext* context) const {
147180 MLFloat16* p_output = output_data + start;
148181 int64_t count = std::min (length_per_task, elem_count - start);
149182 MLFloat16* p_temp = temp_fp16_aligned.get () + start;
150- MlasComputeFP16Gelu (p_input, p_output, p_temp, count, approximation_algorithm_);
183+ MlasComputeFP16Gelu (p_input, p_output, p_temp, count, approximation_algorithm_);
151184
152185 },
153186 0 );
0 commit comments