Skip to content

Commit 4f10c21

Browse files
author
Sanket Kale
committed
Added runtime guards and resolved CIfailures
1 parent 9439d8f commit 4f10c21

File tree

6 files changed

+51
-26
lines changed

6 files changed

+51
-26
lines changed

onnxruntime/core/mlas/lib/erf.cpp

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -285,27 +285,9 @@ MlasComputeFP16Erf(
285285
)
286286
{
287287
#if defined(MLAS_USE_SVE) || defined(MLAS_NEON_INTRINSICS)
288-
289-
#if defined(MLAS_USE_SVE) && defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
290-
if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) {
291-
MlasSveErfF16Kernel(
292-
reinterpret_cast<const _mlas_fp16_*>(Input),
293-
reinterpret_cast<_mlas_fp16_*>(Output),
294-
N
295-
);
296-
return;
297-
}
298-
#endif
299-
300-
#if defined(MLAS_NEON_INTRINSICS) && defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
301-
MlasNeonErfF16Kernel(
302-
reinterpret_cast<const _mlas_fp16_*>(Input),
303-
reinterpret_cast<_mlas_fp16_*>(Output),
304-
N
305-
);
306-
return;
307-
#endif
308-
288+
#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
289+
GetMlasPlatform().ErfF16KernelRoutine(reinterpret_cast<const _mlas_fp16_*>(Input), reinterpret_cast<_mlas_fp16_*>(Output), N);
290+
#endif
309291
#else
310292
std::vector<float> input_fp32(N);
311293
std::vector<float> output_fp32(N);

onnxruntime/core/mlas/lib/erf_neon_fp16.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ Module Name:
1414

1515
#include "erf_neon_fp16.h"
1616

17+
#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
18+
1719
// Helpers to safely convert between float and FP16-bit representation
1820
static float
1921
fp16_to_float(uint16_t h)
@@ -145,3 +147,4 @@ MlasNeonErfF16Kernel(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N)
145147
Output[i] = float_to_fp16(erf_approx);
146148
}
147149
}
150+
#endif

onnxruntime/core/mlas/lib/gelu.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ MlasComputeFP16Gelu(const MLAS_FP16* input,
2222
int64_t count,
2323
const std::string& algo)
2424
{
25-
#if defined(MLAS_USE_SVE) && defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
26-
MlasSveGeluF16Kernel(input, output, temp, count, algo);
27-
#elif defined(MLAS_NEON_INTRINSICS) && defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
28-
MlasNeonGeluF16Kernel(input, output, temp, count, algo);
25+
#if defined(MLAS_USE_SVE) || defined(MLAS_NEON_INTRINSICS)
26+
#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
27+
GetMlasPlatform().GeluF16KernelRoutine(input, output, temp, count, algo);
28+
#endif
2929
#else
3030
(void)temp;
3131
for (int64_t i = 0; i < count; ++i) {

onnxruntime/core/mlas/lib/mlasi.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,25 @@ void
610610
size_t N
611611
);
612612

613+
using _mlas_fp16_ = uint16_t;
614+
typedef
615+
void
616+
(MLASCALL MLAS_COMPUTE_ERF_FP16_KERNEL)(
617+
const _mlas_fp16_* Input,
618+
_mlas_fp16_* Output,
619+
size_t N
620+
);
621+
622+
typedef
623+
void
624+
(MLASCALL MLAS_COMPUTE_GELU_FP16_KERNEL)(
625+
const MLAS_FP16* Input,
626+
MLAS_FP16* Output,
627+
MLAS_FP16* Temp,
628+
int64_t N,
629+
const std::string& Algo
630+
);
631+
613632
typedef
614633
float
615634
(MLASCALL MLAS_COMPUTE_SUMEXP_FLOAT_KERNEL)(
@@ -1057,6 +1076,8 @@ extern "C" {
10571076
MLAS_QUANTIZE_LINEAR_U16_KERNEL MlasQuantizeLinearU16Kernel;
10581077
MLAS_QUANTIZE_LINEAR_S4_KERNEL MlasQuantizeLinearS4Kernel;
10591078
MLAS_QUANTIZE_LINEAR_U4_KERNEL MlasQuantizeLinearU4Kernel;
1079+
MLAS_COMPUTE_ERF_FP16_KERNEL MlasNeonErfF16Kernel;
1080+
MLAS_COMPUTE_GELU_FP16_KERNEL MlasNeonGeluF16Kernel;
10601081
#if defined(MLAS_TARGET_AMD64)
10611082
MLAS_DEQUANTIZE_LINEAR_S8_KERNEL MlasDequantizeLinearS8Kernel;
10621083
MLAS_DEQUANTIZE_LINEAR_U8_KERNEL MlasDequantizeLinearU8Kernel;
@@ -1410,6 +1431,10 @@ struct MLAS_PLATFORM {
14101431
MLAS_COMPUTE_SUMEXP_FLOAT_KERNEL* ComputeSumExpF32Kernel;
14111432
MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeLogSoftmaxOutputF32Kernel;
14121433
MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeSoftmaxOutputF32Kernel;
1434+
#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
1435+
MLAS_COMPUTE_ERF_FP16_KERNEL* ErfF16KernelRoutine;
1436+
MLAS_COMPUTE_GELU_FP16_KERNEL* GeluF16KernelRoutine;
1437+
#endif
14131438
#endif
14141439
#if defined(MLAS_TARGET_AMD64)
14151440
MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1Routine;

onnxruntime/core/mlas/lib/platform.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ Module Name:
1919
#ifdef MLAS_USE_SVE
2020
#include "sve/mlasi_sve.h"
2121
#endif
22+
#if defined(MLAS_NEON_INTRINSICS) && defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
23+
#include "erf_neon_fp16.h"
24+
#include "gelu.h"
25+
#endif
2226
#if defined(USE_KLEIDIAI)
2327
#include "kleidiai/mlasi_kleidiai.h"
2428
#endif
@@ -635,6 +639,17 @@ Return Value:
635639
this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32Kernel;
636640
this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32Kernel;
637641
}
642+
643+
#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
644+
if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) {
645+
this->ErfF16KernelRoutine = MlasSveErfF16Kernel;
646+
this->GeluF16KernelRoutine = MlasSveGeluF16Kernel;
647+
}
648+
else{
649+
this->ErfF16KernelRoutine = MlasNeonErfF16Kernel;
650+
this->GeluF16KernelRoutine = MlasNeonGeluF16Kernel;
651+
}
652+
#endif
638653
#endif
639654

640655
//

onnxruntime/core/providers/cpu/math/element_wise_ops.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2040,7 +2040,7 @@ Status Erf<MLFloat16>::Compute(OpKernelContext* context) const {
20402040
const int64_t count = std::min(length_per_task, elem_count - start);
20412041
const MLFloat16* p_input = input_data + start;
20422042
MLFloat16* p_output = output_data + start;
2043-
MlasComputeFP16Erf(p_input, p_output, count);
2043+
MlasComputeFP16Erf(p_input, p_output, static_cast<size_t>(count));
20442044
},
20452045
0);
20462046

0 commit comments

Comments
 (0)