From 98d54016cbdc187bb0f6c443f23f9f98334b646d Mon Sep 17 00:00:00 2001 From: akote123 Date: Tue, 9 Dec 2025 17:46:55 +0530 Subject: [PATCH 1/7] Enable Gelu Fp16 Seperate platform dependant code --- cmake/onnxruntime_mlas.cmake | 8 + cmake/onnxruntime_providers_cpu.cmake | 9 +- onnxruntime/core/mlas/lib/erf_neon_fp16.cpp | 147 ++++++++++ onnxruntime/core/mlas/lib/erf_neon_fp16.h | 24 ++ onnxruntime/core/mlas/lib/fp16_common.h | 50 ++++ .../mlas/lib/sve/Elementwise_sve_fp16.cpp | 251 ++++++++++++++++++ onnxruntime/core/mlas/lib/sve/mlas_sve_fp16.h | 182 +++++++++++++ onnxruntime/core/mlas/lib/sve/mlasi_sve.h | 11 + onnxruntime/core/mlas/lib/tanh.cpp | 11 +- .../providers/cpu/cpu_execution_provider.cc | 8 +- .../providers/cpu/math/element_wise_ops.cc | 27 +- onnxruntime/core/providers/cpu/tensor/gelu.cc | 180 ++++++++++++- .../cpu/activation/activation_op_test.cc | 52 ++++ 13 files changed, 948 insertions(+), 12 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/erf_neon_fp16.cpp create mode 100644 onnxruntime/core/mlas/lib/erf_neon_fp16.h create mode 100644 onnxruntime/core/mlas/lib/sve/Elementwise_sve_fp16.cpp create mode 100644 onnxruntime/core/mlas/lib/sve/mlas_sve_fp16.h diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index d7dcde945e6d7..d0efa2abe4a6e 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -116,6 +116,8 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp ${MLAS_SRC_DIR}/sconv_nchw_kernel_neon.cpp + ${MLAS_SRC_DIR}/erf_neon_fp16.h + ${MLAS_SRC_DIR}/erf_neon_fp16.cpp ) set(mlas_platform_preprocess_srcs @@ -479,13 +481,17 @@ else() ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp ${MLAS_SRC_DIR}/sconv_nchw_kernel_neon.cpp + ${MLAS_SRC_DIR}/erf_neon_fp16.h + ${MLAS_SRC_DIR}/erf_neon_fp16.cpp ) # Conditionally add the SVE implementation if compiler supports it if (onnxruntime_USE_SVE) list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/mlasi_sve.h) list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/elementwise_sve.cpp) + list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/Elementwise_sve_fp16.cpp) set_source_files_properties(${MLAS_SRC_DIR}/sve/elementwise_sve.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/sve/Elementwise_sve_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 ") list(APPEND mlas_private_compile_definitions MLAS_USE_SVE) endif() @@ -522,6 +528,7 @@ else() ${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp ${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/erf_neon_fp16.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") @@ -538,6 +545,7 @@ else() set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/erf_neon_fp16.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+fp16 ") endif() 
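# Note: the per-file "-march=armv8.2-a+fp16" / "+sve+fp16" COMPILE_FLAGS above only
# change how these translation units are compiled; execution of the SVE kernels is
# additionally gated at runtime via MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve(), so
# the resulting binary stays usable on cores without the SVE extension.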
if(ONNXRUNTIME_MLAS_MULTI_ARCH) diff --git a/cmake/onnxruntime_providers_cpu.cmake b/cmake/onnxruntime_providers_cpu.cmake index f77a5dd78fcc5..79feaaec00b75 100644 --- a/cmake/onnxruntime_providers_cpu.cmake +++ b/cmake/onnxruntime_providers_cpu.cmake @@ -182,7 +182,14 @@ if (onnxruntime_ENABLE_CPU_FP16_OPS) set_source_files_properties(${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/adasum_kernels.cc PROPERTIES COMPILE_FLAGS " -fassociative-math -ffast-math -ftree-vectorize -funsafe-math-optimizations -mf16c -mavx -mfma ") endif() -target_include_directories(onnxruntime_providers PRIVATE ${ONNXRUNTIME_ROOT}) +if(onnxruntime_target_platform STREQUAL "aarch64" OR onnxruntime_target_platform STREQUAL "ARM64" OR onnxruntime_target_platform STREQUAL "arm64") +set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/tensor/gelu.cc" PROPERTIES COMPILE_FLAGS -march=armv8.2-a+fp16) +endif() +target_include_directories(onnxruntime_providers PRIVATE + ${ONNXRUNTIME_ROOT} + ${ONNXRUNTIME_ROOT}/core/mlas/inc +) + onnxruntime_add_include_to_target(onnxruntime_providers re2::re2 Eigen3::Eigen) add_dependencies(onnxruntime_providers onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) diff --git a/onnxruntime/core/mlas/lib/erf_neon_fp16.cpp b/onnxruntime/core/mlas/lib/erf_neon_fp16.cpp new file mode 100644 index 0000000000000..60c7695346a54 --- /dev/null +++ b/onnxruntime/core/mlas/lib/erf_neon_fp16.cpp @@ -0,0 +1,147 @@ +/*++ + +Copyright 2025 FUJITSU LIMITED + +Module Name: + + erf_neon_fp16.cpp + +Abstract: + + This module contains the procedure prototypes for the ERF NEON FP16 intrinsics. + +--*/ + +#include "erf_neon_fp16.h" + +// Helpers to safely convert between float and FP16-bit representation +static float +fp16_to_float(uint16_t h) +{ + __fp16 tmp; + memcpy(&tmp, &h, sizeof(h)); + return (float)tmp; +} + +static uint16_t +float_to_fp16(float f) +{ + __fp16 tmp = (__fp16)f; + uint16_t h; + memcpy(&h, &tmp, sizeof(h)); + return h; +} + +static inline MLAS_FLOAT16X8 +exp_neg_rational_approx_f16(MLAS_FLOAT16X8 x) +{ + const float16_t a0 = 6.0f; + MLAS_FLOAT16X8 max_x = MlasBroadcastF16Float16x8(a0); + x = MlasMinimumFloat16(x, max_x); + + const float16_t c0 = 1.330f; + const float16_t c1 = -0.390f; + const float16_t c2 = 0.0288f; + + const float16_t d0 = 1.338f; + const float16_t d1 = 0.848f; + const float16_t d2 = 0.467f; + + MLAS_FLOAT16X8 c0v = MlasBroadcastF16Float16x8(c0); + MLAS_FLOAT16X8 c1v = MlasBroadcastF16Float16x8(c1); + MLAS_FLOAT16X8 c2v = MlasBroadcastF16Float16x8(c2); + + MLAS_FLOAT16X8 d0v = MlasBroadcastF16Float16x8(d0); + MLAS_FLOAT16X8 d1v = MlasBroadcastF16Float16x8(d1); + MLAS_FLOAT16X8 d2v = MlasBroadcastF16Float16x8(d2); + MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16(x, x); + MLAS_FLOAT16X8 num = MlasMultiplyAddFloat16(c1v, x, c0v); + num = MlasMultiplyAddFloat16(c2v, x2, num); + MLAS_FLOAT16X8 den = MlasMultiplyAddFloat16(d1v, x, d0v); + den = MlasMultiplyAddFloat16(d2v, x2, den); + MLAS_FLOAT16X8 recip = MlasApproximateReciprocalFloat16(den); + recip = MlasMultiplyFloat16(recip, MlasReciprocalSqrtFloat16(den, recip)); + recip = MlasMultiplyFloat16(recip, MlasReciprocalSqrtFloat16(den, recip)); + MLAS_FLOAT16X8 result = MlasMultiplyFloat16(num, recip); + return result; +} + +void +MlasNeonErfKernelFp16(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N) +{ + const float16_t p = 0.328f; + const float16_t a1 = 0.2505f; + const float16_t a2 = -0.2881f; + const float16_t a3 = 1.4102f; + const float16_t a4 = -1.423f; + const float16_t a5 = 1.0547f; + + MLAS_FLOAT16X8 vp 
= MlasBroadcastF16Float16x8(p); + MLAS_FLOAT16X8 va1 = MlasBroadcastF16Float16x8(a1); + MLAS_FLOAT16X8 va2 = MlasBroadcastF16Float16x8(a2); + MLAS_FLOAT16X8 va3 = MlasBroadcastF16Float16x8(a3); + MLAS_FLOAT16X8 va4 = MlasBroadcastF16Float16x8(a4); + MLAS_FLOAT16X8 va5 = MlasBroadcastF16Float16x8(a5); + + constexpr float16_t one_fp16 = 1.0f; + constexpr float16_t neg_one_fp16 = -1.0f; + constexpr float16_t zero_fp16 = 0.0f; + constexpr float16_t four_fp16 = 4.0f; + + MLAS_FLOAT16X8 vone = MlasBroadcastF16Float16x8(one_fp16); + MLAS_FLOAT16X8 vneg_one = MlasBroadcastF16Float16x8(neg_one_fp16); + MLAS_FLOAT16X8 vzero = MlasBroadcastF16Float16x8(zero_fp16); + MLAS_FLOAT16X8 vth = MlasBroadcastF16Float16x8(four_fp16); + + size_t i = 0; + for (; i + 8 <= N; i += 8) { + MLAS_FLOAT16X8 x = MlasLoadFloat16x8(&Input[i]); + MLAS_UINT16X8 neg_mask = MlasCompareLessThanFloat16(x, vzero); + MLAS_FLOAT16X8 sign = MlasSelectFloat16(neg_mask, vneg_one, vone); + MLAS_FLOAT16X8 absx = MlasAbsFloat16(x); + MLAS_UINT16X8 use_mask = MlasCompareLessThanFloat16(absx, vth); + MLAS_FLOAT16X8 absx_clamped = MlasMinimumFloat16(absx, vth); + MLAS_FLOAT16X8 denom = MlasMultiplyAddFloat16(vp, absx_clamped, vone); + MLAS_FLOAT16X8 t = MlasApproximateReciprocalFloat16(denom); + t = MlasMultiplyFloat16(t, MlasReciprocalSqrtFloat16(denom, t)); + t = MlasMultiplyFloat16(t, MlasReciprocalSqrtFloat16(denom, t)); + MLAS_FLOAT16X8 t2 = MlasMultiplyFloat16(t, t); + MLAS_FLOAT16X8 t3 = MlasMultiplyFloat16(t2, t); + MLAS_FLOAT16X8 t4 = MlasMultiplyFloat16(t3, t); + MLAS_FLOAT16X8 t5 = MlasMultiplyFloat16(t4, t); + MLAS_FLOAT16X8 poly = MlasMultiplyFloat16(va1, t); + poly = MlasMultiplyAddFloat16(va2, t2, poly); + poly = MlasMultiplyAddFloat16(va3, t3, poly); + poly = MlasMultiplyAddFloat16(va4, t4, poly); + poly = MlasMultiplyAddFloat16(va5, t5, poly); + MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16(absx_clamped, absx_clamped); + MLAS_FLOAT16X8 exp_neg_x2 = exp_neg_rational_approx_f16(x2); + MLAS_FLOAT16X8 poly_mul_exp = MlasMultiplyFloat16(poly, exp_neg_x2); + MLAS_FLOAT16X8 one_minus_term = MlasSubtractFloat16(vone, poly_mul_exp); + MLAS_FLOAT16X8 erf_approx = MlasMultiplyFloat16(sign, one_minus_term); + erf_approx = MlasMinimumFloat16(erf_approx, vone); + erf_approx = MlasMaximumFloat16(erf_approx, vneg_one); + MLAS_FLOAT16X8 result = MlasSelectFloat16(use_mask, erf_approx, sign); + MlasStoreFloat16x8(&Output[i], result); + } + + for (; i < N; i++) { + float x = fp16_to_float(Input[i]); + float sign = (x < 0) ? -1.0f : 1.0f; + float absx = fabsf(x); + + if (absx > 4.0f) { + Output[i] = float_to_fp16(sign); + continue; + } + + float t = 1.0f / (1.0f + p * absx); + float poly = a1 * t + a2 * t * t + a3 * t * t * t + a4 * t * t * t * t + a5 * t * t * t * t * t; + float exp_neg_x2 = expf(-absx * absx); + float erf_approx = sign * (1.0f - poly * exp_neg_x2); + if (erf_approx > 1.0f) erf_approx = 1.0f; + if (erf_approx < -1.0f) erf_approx = -1.0f; + + Output[i] = float_to_fp16(erf_approx); + } +} \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/erf_neon_fp16.h b/onnxruntime/core/mlas/lib/erf_neon_fp16.h new file mode 100644 index 0000000000000..154ef0104baa3 --- /dev/null +++ b/onnxruntime/core/mlas/lib/erf_neon_fp16.h @@ -0,0 +1,24 @@ +/*++ + +Copyright 2025 FUJITSU LIMITED + +Module Name: + + erf_neon_fp16.h + +Abstract: + + This module contains the procedure prototypes for the ERF NEON FP16 intrinsics. 
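+
+  The kernel declared here approximates erf(x) as
+  sign(x) * (1 - (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5) * exp(-x^2)),
+  with t = 1 / (1 + p*|x|), the result clamped to [-1, 1] and inputs with
+  |x| > 4 mapped directly to sign(x); the vector path substitutes a small
+  rational approximation for exp(-x^2).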
+ +--*/ + +#pragma once + +#include + +#include "mlasi.h" +#include "fp16_common.h" +#include "softmax_kernel_neon.h" + +using _mlas_fp16_ = uint16_t; +void MlasNeonErfKernelFp16(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N); diff --git a/onnxruntime/core/mlas/lib/fp16_common.h b/onnxruntime/core/mlas/lib/fp16_common.h index d4713cce5a176..71655f9b59905 100644 --- a/onnxruntime/core/mlas/lib/fp16_common.h +++ b/onnxruntime/core/mlas/lib/fp16_common.h @@ -54,6 +54,10 @@ MLAS_FORCEINLINE MLAS_FLOAT16X8 MlasBroadcastFloat16x8(_mlas_fp16_ Value) { return vreinterpretq_f16_p16(vdupq_n_p16(Value)); } +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasBroadcastF16Float16x8(float16_t Value) { return vdupq_n_f16(Value); } + MLAS_FORCEINLINE MLAS_FLOAT16X4 MlasBroadcastFloat16x4(_mlas_fp16_ Value) { return vreinterpret_f16_p16(vdup_n_p16(Value)); } @@ -78,6 +82,10 @@ MLAS_FORCEINLINE MLAS_FLOAT16X8 MlasLoadFloat16x8(const _mlas_fp16_* Buffer) { return vreinterpretq_f16_u16(vld1q_u16(Buffer)); } +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasLoadf16Float16x8(const float16_t* Buffer) { return vld1q_f16(Buffer); } + MLAS_FORCEINLINE MLAS_FLOAT16X4 MlasLoadFloat16x4(const _mlas_fp16_* Buffer) { return vreinterpret_f16_u16(vld1_u16(Buffer)); } @@ -115,6 +123,13 @@ MlasStoreFloat16x8(_mlas_fp16_* Buffer, MLAS_FLOAT16X8 Vector) vst1q_u16(Buffer, vreinterpretq_u16_f16(Vector)); } +MLAS_FORCEINLINE +void +MlasStoref16Float16x8(float16_t* Buffer, MLAS_FLOAT16X8 Vector) +{ + vst1q_f16(Buffer, Vector); +} + MLAS_FORCEINLINE void MlasStoreFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector) @@ -579,4 +594,39 @@ MlasShiftLeftInt16(MLAS_INT16X4 Vector) return vshl_n_s16(Vector, ShiftCount); } +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasReciprocalSqrtFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +{ + return vrecpsq_f16(Vector1, Vector2); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasApproximateReciprocalFloat16(MLAS_FLOAT16X8 Vector) +{ + return vrecpeq_f16(Vector); +} + +MLAS_FORCEINLINE +MLAS_UINT16X8 +MlasCompareLessThanFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +{ + return vcltq_f16(Vector1, Vector2); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasAbsFloat16(MLAS_FLOAT16X8 Vector) +{ + return vabsq_f16(Vector); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasSelectFloat16(MLAS_UINT16X8 Vector, MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +{ + return vbslq_f16(Vector, Vector1, Vector2); +} + #endif // fp16 vector intrinsic supported diff --git a/onnxruntime/core/mlas/lib/sve/Elementwise_sve_fp16.cpp b/onnxruntime/core/mlas/lib/sve/Elementwise_sve_fp16.cpp new file mode 100644 index 0000000000000..2530c712a8866 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sve/Elementwise_sve_fp16.cpp @@ -0,0 +1,251 @@ +/*++ + +Copyright 2025 FUJITSU LIMITED + +Module Name: + + Elementwise_sve_fp16.cpp + +Abstract: + + This module contains the SVE Elementwise functions . 
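+
+  It provides predicated SVE FP16 kernels for tanh (a rational polynomial
+  p(x)/q(x) approximation over inputs clamped to [-3.515625, 3.515625]),
+  erf (the same polynomial-times-exp(-x^2) scheme as the NEON kernel), and
+  Gelu ("tanh" and exact "none" variants), processing svcnth() lanes per
+  iteration under an svwhilelt predicate.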
+ +--*/ +#include "mlas_sve_fp16.h" + +struct MlasTanhConstants_fp16_scalar { + __fp16 LowerRange; + __fp16 UpperRange; + __fp16 alpha_7; + __fp16 alpha_5; + __fp16 alpha_3; + __fp16 alpha_1; + __fp16 beta_6; + __fp16 beta_4; + __fp16 beta_2; + __fp16 beta_0; +}; + +constexpr MlasTanhConstants_fp16_scalar TanhConstantsFp16 = { + -3.515625f, + 3.515625f, + 5.960464477539063e-08f, + 1.4841556549072266e-05f, + 0.000637054443359375f, + 0.004894256591796875f, + 1.1920928955078125e-06f, + 0.00011855363845825195f, + 0.0022678375244140625f, + 0.004894256591796875f +}; + + +static inline MLAS_SVFLOAT16 +Tanh_Vector_SVE_fp16(MLAS_SVFLOAT16 x, MLAS_SVBOOL pg) +{ + MLAS_SVFLOAT16 g_LowerRange_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.LowerRange); + MLAS_SVFLOAT16 g_UpperRange_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.UpperRange); + MLAS_SVFLOAT16 g_alpha_7_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.alpha_7); + MLAS_SVFLOAT16 g_alpha_5_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.alpha_5); + MLAS_SVFLOAT16 g_alpha_3_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.alpha_3); + MLAS_SVFLOAT16 g_alpha_1_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.alpha_1); + MLAS_SVFLOAT16 g_beta_6_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.beta_6); + MLAS_SVFLOAT16 g_beta_4_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.beta_4); + MLAS_SVFLOAT16 g_beta_2_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.beta_2); + MLAS_SVFLOAT16 g_beta_0_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.beta_0); + + x = MlasSveMinfloat16(pg, x, g_UpperRange_vec); + x = MlasSveMaxfloat16(pg, x, g_LowerRange_vec); + + MLAS_SVFLOAT16 x2 = MlasSveMulfloat16(pg, x, x); + MLAS_SVFLOAT16 p = MlasSveMLAfloat16(pg, g_alpha_5_vec, g_alpha_7_vec, x2); + p = MlasSveMLAfloat16(pg, g_alpha_3_vec, p, x2); + p = MlasSveMLAfloat16(pg, g_alpha_1_vec, p, x2); + p = MlasSveMulfloat16(pg, p, x); + + svfloat16_t q = MlasSveMLAfloat16(pg, g_beta_4_vec, g_beta_6_vec, x2); + q = MlasSveMLAfloat16(pg, g_beta_2_vec, q, x2); + q = MlasSveMLAfloat16(pg, g_beta_0_vec, q, x2); + + MLAS_SVFLOAT16 res = MlasSveDivfloat16(pg, p, q); + + return res; +} + +void +MlasTanhKernelFp16_SVE(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N) +{ + size_t offset = 0; + const auto* input = reinterpret_cast(Input); + auto* output = reinterpret_cast<_mlas_fp16_*>(Output); + while (offset < N) { + MLAS_SVBOOL pg = MlasSveSelPredictefloat16(offset, N); + MLAS_SVFLOAT16 x = MlasSvereinterpretf16_u16(MlasSveLoadUint16(pg, &input[offset])); + MLAS_SVFLOAT16 y = Tanh_Vector_SVE_fp16(x, pg); + MlasSveStoreUint16(pg, &output[offset], MlasSvereinterpretu16_f16(y)); + offset += svcnth(); + } +} + +static inline MLAS_SVFLOAT16 +exp_neg_rational_approx_f16(MLAS_SVBOOL pg, MLAS_SVFLOAT16 x) +{ + const __fp16 a0 = 6.0f; + MLAS_SVFLOAT16 max_x = MlasSveBroadcastfloat16(a0); + x = MlasSveMinfloat16(pg, x, max_x); + + const __fp16 c0 = 1.330f; + const __fp16 c1 = -0.390f; + const __fp16 c2 = 0.0288f; + + const __fp16 d0 = 1.338f; + const __fp16 d1 = 0.848f; + const __fp16 d2 = 0.467f; + + MLAS_SVFLOAT16 c0v = MlasSveBroadcastfloat16(c0); + MLAS_SVFLOAT16 c1v = MlasSveBroadcastfloat16(c1); + MLAS_SVFLOAT16 c2v = MlasSveBroadcastfloat16(c2); + MLAS_SVFLOAT16 d0v = MlasSveBroadcastfloat16(d0); + MLAS_SVFLOAT16 d1v = MlasSveBroadcastfloat16(d1); + MLAS_SVFLOAT16 d2v = MlasSveBroadcastfloat16(d2); + MLAS_SVFLOAT16 x2 = MlasSveMulfloat16(pg, x, x); + + MLAS_SVFLOAT16 num = MlasSveMLAfloat16(pg, c0v, c1v, x); + num = MlasSveMLAfloat16(pg, num, c2v, x2); + + MLAS_SVFLOAT16 den = 
MlasSveMLAfloat16(pg, d0v, d1v, x); + den = MlasSveMLAfloat16(pg, den, d2v, x2); + + MLAS_SVFLOAT16 recip = MlasSveReciprocalfloat16(den); + recip = MlasSveMulfloat16(pg, recip, MlasSveReciprocalStepfloat16(den, recip)); + recip = MlasSveMulfloat16(pg, recip, MlasSveReciprocalStepfloat16(den, recip)); + + MLAS_SVFLOAT16 result = MlasSveMulfloat16(pg, num, recip); + return result; +} + +void MLASCALL +MlasSveErfKernelFp16(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N) +{ + const __fp16 p = 0.328f; + const __fp16 a1 = 0.2505f; + const __fp16 a2 = -0.2881f; + const __fp16 a3 = 1.4102f; + const __fp16 a4 = -1.423f; + const __fp16 a5 = 1.0547f; + + MLAS_SVFLOAT16 vp = MlasSveBroadcastfloat16(p); + MLAS_SVFLOAT16 va1 = MlasSveBroadcastfloat16(a1); + MLAS_SVFLOAT16 va2 = MlasSveBroadcastfloat16(a2); + MLAS_SVFLOAT16 va3 = MlasSveBroadcastfloat16(a3); + MLAS_SVFLOAT16 va4 = MlasSveBroadcastfloat16(a4); + MLAS_SVFLOAT16 va5 = MlasSveBroadcastfloat16(a5); + + const __fp16 v1 = 1.0f; + const __fp16 v2 = -1.0f; + const __fp16 v3 = 0.0f; + const __fp16 v4 = 4.0f; + MLAS_SVFLOAT16 vone = MlasSveBroadcastfloat16(v1); + MLAS_SVFLOAT16 vneg_one = MlasSveBroadcastfloat16(v2); + MLAS_SVFLOAT16 vzero = MlasSveBroadcastfloat16(v3); + MLAS_SVFLOAT16 vth = MlasSveBroadcastfloat16(v4); + + size_t i = 0; + while (i < N) { + MLAS_SVBOOL pg = MlasSveSelPredictefloat16(i, N); + MLAS_SVFLOAT16 x = MlasSvereinterpretf16_u16(MlasSveLoadUint16(pg, &Input[i])); + MLAS_SVBOOL neg_mask = MlasSveComparelessthanfloat16(pg, x, vzero); + MLAS_SVFLOAT16 sign = MlasSveSelectfloat16(neg_mask, vneg_one, vone); + MLAS_SVFLOAT16 absx = MlasSveAbsolutefloat16(MlasSveBroadcastfloat16(v3), pg, x); + svbool_t use_mask = MlasSveComparelessthanfloat16(pg, absx, vth); + MLAS_SVFLOAT16 absx_clamped = MlasSveMinfloat16(pg, absx, vth); + MLAS_SVFLOAT16 denom = MlasSveMLAfloat16(pg, vone, vp, absx_clamped); + MLAS_SVFLOAT16 t = MlasSveReciprocalfloat16(denom); + t = MlasSveMulfloat16(pg, t, MlasSveReciprocalStepfloat16(denom, t)); + t = MlasSveMulfloat16(pg, t, MlasSveReciprocalStepfloat16(denom, t)); + MLAS_SVFLOAT16 t2 = MlasSveMulfloat16(pg, t, t); + MLAS_SVFLOAT16 t3 = MlasSveMulfloat16(pg, t2, t); + MLAS_SVFLOAT16 t4 = MlasSveMulfloat16(pg, t3, t); + MLAS_SVFLOAT16 t5 = MlasSveMulfloat16(pg, t4, t); + svfloat16_t poly = MlasSveMulfloat16(pg, va1, t); + poly = MlasSveMLAfloat16(pg, poly, va2, t2); + poly = MlasSveMLAfloat16(pg, poly, va3, t3); + poly = MlasSveMLAfloat16(pg, poly, va4, t4); + poly = MlasSveMLAfloat16(pg, poly, va5, t5); + MLAS_SVFLOAT16 x2 = MlasSveMulfloat16(pg, absx_clamped, absx_clamped); + MLAS_SVFLOAT16 exp_neg_x2 = exp_neg_rational_approx_f16(pg, x2); + MLAS_SVFLOAT16 poly_mul_exp = MlasSveMulfloat16(pg, poly, exp_neg_x2); + MLAS_SVFLOAT16 one_minus_term = MlasSveSubtractfloat16(pg, vone, poly_mul_exp); + MLAS_SVFLOAT16 erf_approx = MlasSveMulfloat16(pg, sign, one_minus_term); + erf_approx = MlasSveMinfloat16(pg, erf_approx, vone); + erf_approx = MlasSveMaxfloat16(pg, erf_approx, vneg_one); + MLAS_SVFLOAT16 result = MlasSveSelectfloat16(use_mask, erf_approx, sign); + MlasSveStoreUint16(pg, &Output[i], MlasSvereinterpretu16_f16(result)); + i += svcntp_b16(svptrue_b16(), pg); + } +} + +void MLASCALL +ComputeGeluFp16_SVE(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp, int64_t count, const std::string& algo) +{ + const __fp16 r1 = 0.5f; + const __fp16 r2 = 1.0f; + const __fp16 r3 = static_cast(M_SQRT1_2); + const __fp16 r4 = 0.7979f; + const __fp16 r5 = 0.03568f; + + const MLAS_SVFLOAT16 v_half = 
MlasSveBroadcastfloat16(r1); + const MLAS_SVFLOAT16 v_one = MlasSveBroadcastfloat16(r2); + const MLAS_SVFLOAT16 v_sqrt1_2 = MlasSveBroadcastfloat16(r3); + const MLAS_SVFLOAT16 v_B = MlasSveBroadcastfloat16(r4); + const MLAS_SVFLOAT16 v_C = MlasSveBroadcastfloat16(r5); + + const __fp16 c1 = -5.0f; + const __fp16 c2 = 5.0f; + if (algo == "tanh") { + int64_t i = 0; + while (i < (count)) { + svbool_t pg = MlasSveSelPredictefloat16(i, count); + MLAS_SVFLOAT16 v_x = MlasSveLoadFloat16(pg, &input[i]); + MLAS_SVFLOAT16 v_x2 = MlasSveMulfloat16(pg, v_x, v_x); + MLAS_SVFLOAT16 v_inner = MlasSveMLAfloat16(pg, v_B, v_C, v_x2); + MLAS_SVFLOAT16 v_tanh_arg = MlasSveMulfloat16(pg, v_x, v_inner); + v_tanh_arg = MlasSveMaxfloat16(pg, MlasSveBroadcastfloat16(c1), MlasSveMinfloat16(pg, v_tanh_arg, MlasSveBroadcastfloat16(c2))); + MlasSveStoreF16(pg, &temp[i], v_tanh_arg); + i += svcnth(); + } + + MlasComputeTanh(reinterpret_cast(temp), reinterpret_cast(temp), count); + + int64_t j = 0; + while (j < (count)) { + svbool_t pg = MlasSveSelPredictefloat16(j, count); + MLAS_SVFLOAT16 v_x = MlasSveLoadFloat16(pg, &input[j]); + MLAS_SVFLOAT16 v_tanh = MlasSveLoadFloat16(pg, &temp[j]); + MLAS_SVFLOAT16 v_result = MlasSveMulfloat16(pg, v_half, MlasSveMulfloat16(pg, v_x, svadd_f16_m(pg, v_one, v_tanh))); + MlasSveStoreF16(pg, &output[j], v_result); + j += svcnth(); + } + } else if (algo == "none") { + int64_t i = 0; + while (i < (count)) { + svbool_t pg = MlasSveSelPredictefloat16(i, count); + MLAS_SVFLOAT16 v_x = MlasSveLoadFloat16(pg, &input[i]); + MLAS_SVFLOAT16 v_scaled = MlasSveMulfloat16(pg, v_x, v_sqrt1_2); + MlasSveStoreF16(pg, &temp[i], v_scaled); + i += svcnth(); + } + + MlasSveErfKernelFp16(reinterpret_cast(temp), reinterpret_cast<_mlas_fp16_*>(temp), count); + + int64_t j = 0; + while (j < (count)) { + svbool_t pg = MlasSveSelPredictefloat16(j, count); + MLAS_SVFLOAT16 v_x = MlasSveLoadFloat16(pg, &input[j]); + MLAS_SVFLOAT16 v_erf = MlasSveLoadFloat16(pg, &temp[j]); + MLAS_SVFLOAT16 v_result = MlasSveMulfloat16(pg, v_half, MlasSveMulfloat16(pg, v_x, MlasSveAddfloat16(pg, v_one, v_erf))); + MlasSveStoreF16(pg, &output[j], v_result); + j += svcnth(); + } + } +} diff --git a/onnxruntime/core/mlas/lib/sve/mlas_sve_fp16.h b/onnxruntime/core/mlas/lib/sve/mlas_sve_fp16.h new file mode 100644 index 0000000000000..45379b8da1476 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sve/mlas_sve_fp16.h @@ -0,0 +1,182 @@ +/*++ + +Copyright 2025 FUJITSU LIMITED + +Module Name: + + mlas_sve_fp16.h + +Abstract: + + This module contains the procedure prototypes for the SVE FP16 intrinsics. 
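+
+  The helpers declared below are thin MLAS_FORCEINLINE wrappers over the
+  corresponding SVE ACLE intrinsics (svdup_f16, svmin_f16_m, svmla_f16_m,
+  svrecpe_f16, ...) so the FP16 element-wise kernels can be written against
+  MLAS-style names.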
+ +--*/ + +#pragma once +#include +#include // for isnan if needed +#include + +#include "mlasi_sve.h" + +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveBroadcastfloat16(__fp16 Value) +{ + return svdup_f16(Value); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveMinfloat16(MLAS_SVBOOL pg, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 range) +{ + return svmin_f16_m(pg, x, range); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveMaxfloat16(MLAS_SVBOOL pg, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 range) +{ + return svmax_f16_m(pg, x, range); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveMulfloat16(MLAS_SVBOOL pg, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 y) +{ + return svmul_f16_m(pg, x, y); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveMLAfloat16(MLAS_SVBOOL pg, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 y, MLAS_SVFLOAT16 z) +{ + return svmla_f16_m(pg, x, y, z); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveDivfloat16(MLAS_SVBOOL pg, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 y) +{ + return svdiv_f16_m(pg, x, y); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVBOOL +MlasSveSelPredictefloat16(size_t x, size_t y) +{ + return svwhilelt_b16(x, y); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSvereinterpretf16_u16(MLAS_SVUINT16 x) +{ + return svreinterpret_f16_u16(x); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVUINT16 +MlasSveLoadUint16(MLAS_SVBOOL pg, const uint16_t* x) +{ + return svld1_u16(pg, x); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveLoadFloat16(MLAS_SVBOOL pg, const MLAS_FP16* x) +{ + return svld1_f16(pg, reinterpret_cast(x)); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +void +MlasSveStoreUint16(MLAS_SVBOOL pg, uint16_t* Buffer, MLAS_SVUINT16 Vector) +{ + return svst1_u16(pg, Buffer, Vector); +} +MLAS_SVE_TARGET +MLAS_FORCEINLINE +void +MlasSveStoreF16(MLAS_SVBOOL pg, MLAS_FP16* Buffer, MLAS_SVFLOAT16 Vector) +{ + return svst1_f16(pg, reinterpret_cast<__fp16*>(Buffer), Vector); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVUINT16 +MlasSvereinterpretu16_f16(MLAS_SVFLOAT16 x) +{ + return svreinterpret_u16_f16(x); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveReciprocalfloat16(MLAS_SVFLOAT16 x) +{ + return svrecpe_f16(x); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveReciprocalStepfloat16(MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 y) +{ + return svrecps_f16(x, y); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveSelectfloat16(MLAS_SVBOOL Pred, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 y) +{ + return svsel_f16(Pred, x, y); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveSubtractfloat16(MLAS_SVBOOL Pred, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 y) +{ + return svsub_f16_m(Pred, x, y); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVBOOL +MlasSveComparelessthanfloat16(MLAS_SVBOOL Pred, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 y) +{ + return svcmplt_f16(Pred, x, y); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveAbsolutefloat16(MLAS_SVFLOAT16 inactive, MLAS_SVBOOL Pred, MLAS_SVFLOAT16 y) +{ + return svabs_f16_m(inactive, Pred, y); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveAddfloat16(MLAS_SVBOOL Pred, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 y) +{ + return svadd_f16_m(Pred, x, y); +} +#endif diff --git a/onnxruntime/core/mlas/lib/sve/mlasi_sve.h b/onnxruntime/core/mlas/lib/sve/mlasi_sve.h index 67a4bf453dd05..a0ff9c14a0159 100644 --- a/onnxruntime/core/mlas/lib/sve/mlasi_sve.h +++ 
b/onnxruntime/core/mlas/lib/sve/mlasi_sve.h @@ -32,7 +32,18 @@ typedef svfloat32_t MLAS_SVFLOAT32; typedef svint32_t MLAS_SVINT32; typedef svuint32_t MLAS_SVUINT32; typedef svbool_t MLAS_SVBOOL; +typedef svfloat16_t MLAS_SVFLOAT16; +typedef svuint16_t MLAS_SVUINT16; +using _mlas_fp16_ = uint16_t; + +void +MLASCALL +MlasSveErfKernelFp16( + const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N + ); +void MLASCALL MlasTanhKernelFp16_SVE(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N); +void MLASCALL ComputeGeluFp16_SVE(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp, int64_t count, const std::string& algo); // function decarations MLAS_FORCEINLINE MLAS_SVFLOAT32 diff --git a/onnxruntime/core/mlas/lib/tanh.cpp b/onnxruntime/core/mlas/lib/tanh.cpp index 63bd744795535..fcd1492e19ee9 100644 --- a/onnxruntime/core/mlas/lib/tanh.cpp +++ b/onnxruntime/core/mlas/lib/tanh.cpp @@ -22,7 +22,9 @@ Module Name: #include "mlasi.h" #include "softmax.h" - +#ifdef MLAS_USE_SVE +#include "sve/mlasi_sve.h" +#endif // // Bundles the floating point constants for use by kernels written in assembly. // @@ -193,6 +195,13 @@ MlasComputeTanh( MLAS_FP16* Output, size_t N ) { +#if defined(MLAS_USE_SVE) + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) { + MlasTanhKernelFp16_SVE(Input, Output, N); + return; + } +#endif + const auto* dispatch = GetMlasPlatform().SoftmaxDispatch; if (dispatch == nullptr || dispatch->Tanh_Fp16 == nullptr) { MLAS_THROW_EX(std::runtime_error, "Tanh_Fp16 is not supported."); diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 0ad8d1d4fef4d..72f5ddc7fd19a 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -1228,7 +1228,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, BFloat16, IsNaN); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Gelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, Gelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu); #if !defined(DISABLE_FLOAT8_TYPES) class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FNUZ, IsNaN); @@ -3269,7 +3270,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { IsNaN)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #if !defined(DISABLE_FLOAT8_TYPES) BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index b940d71e1165e..12eebc1339751 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -11,6 +11,10 @@ #include "core/util/math.h" #include "core/mlas/inc/mlas.h" +#if defined(MLAS_NEON_INTRINSICS) +#include "core/mlas/lib/erf_neon_fp16.h" +#endif + #include namespace onnxruntime { @@ -2027,7 +2031,6 @@ Status 
Erf::Compute(OpKernelContext* context) const { int64_t elem_count = X->Shape().Size(); constexpr int64_t length_per_task = 4096; int64_t task_count = (elem_count + length_per_task - 1) / length_per_task; - const auto narrow_task_count = onnxruntime::narrow(task_count); // get allocator for temporary buffers @@ -2044,14 +2047,32 @@ Status Erf::Compute(OpKernelContext* context) const { const MLFloat16* p_input = input_data + start; MLFloat16* p_output = output_data + start; - // allocate temp buffers using ORT allocator +#if defined(MLAS_USE_SVE) || defined(MLAS_NEON_INTRINSICS) +#if defined(MLAS_USE_SVE) + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) { + MlasSveErfKernelFp16( + reinterpret_cast(p_input), + reinterpret_cast<_mlas_fp16_*>(p_output), + narrow_count); + return; + } +#endif +#if defined(MLAS_NEON_INTRINSICS) + MlasNeonErfKernelFp16( + reinterpret_cast(p_input), + reinterpret_cast<_mlas_fp16_*>(p_output), + narrow_count); + return; +#endif +#else + // Fallback: convert half to float, compute erf, convert back IAllocatorUniquePtr input_fp32 = IAllocator::MakeUniquePtr(alloc, narrow_count); IAllocatorUniquePtr output_fp32 = IAllocator::MakeUniquePtr(alloc, narrow_count); - // convert, compute, convert back MlasConvertHalfToFloatBuffer(p_input, input_fp32.get(), narrow_count); MlasComputeErf(input_fp32.get(), output_fp32.get(), narrow_count); MlasConvertFloatToHalfBuffer(output_fp32.get(), p_output, narrow_count); +#endif }, 0); diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.cc b/onnxruntime/core/providers/cpu/tensor/gelu.cc index d55973eda180f..f7bf01c609ae1 100644 --- a/onnxruntime/core/providers/cpu/tensor/gelu.cc +++ b/onnxruntime/core/providers/cpu/tensor/gelu.cc @@ -11,6 +11,14 @@ #include #include "core/providers/cpu/element_wise_ranged_transform.h" #include "core/providers/cpu/tensor/gelu.h" +#include "core/mlas/lib/fp16_common.h" +#if defined(MLAS_NEON_INTRINSICS) +#include "core/mlas/lib/erf_neon_fp16.h" +#endif + +#ifdef MLAS_USE_SVE +#include "core/mlas/lib/sve/mlasi_sve.h" +#endif using onnxruntime::narrow; using namespace onnxruntime::common; @@ -19,11 +27,17 @@ namespace onnxruntime { // May revisit the implementations to support inplace computation, if needed. 
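// For reference, the two formulations the FP16 path below implements, shown as a
// scalar sketch that mirrors the float fallback in this file (GeluScalarReference
// is an illustrative name only, not a function added by this change):
//
//   static inline float GeluScalarReference(float x, bool use_tanh_approx) {
//     if (use_tanh_approx) {
//       // "tanh": 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
//       const float B = 0.7978845608f;   // sqrt(2/pi)
//       const float C = 0.044715f * B;
//       return 0.5f * x * (1.0f + std::tanh(x * (B + C * x * x)));
//     }
//     // "none": exact Gelu, 0.5 * x * (1 + erf(x / sqrt(2)))
//     return 0.5f * x * (1.0f + std::erf(x * static_cast<float>(M_SQRT1_2)));
//   }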
-ONNX_CPU_OPERATOR_KERNEL( - Gelu, - 20, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Gelu); +#define ADD_TYPED_GELU_OP(data_type) \ + ONNX_CPU_OPERATOR_TYPED_KERNEL( \ + Gelu, \ + 20, \ + data_type, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Gelu) + +ADD_TYPED_GELU_OP(float); +ADD_TYPED_GELU_OP(MLFloat16); #ifndef DISABLE_CONTRIB_OPS namespace contrib { @@ -105,4 +119,160 @@ Status Gelu::Compute(OpKernelContext* context) const { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported approximation_algorithm: ", approximation_algorithm_); } +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + +void ComputeGeluFp16_NEON(const MLFloat16* input, MLFloat16* output, MLFloat16* temp, int64_t count, const std::string& algo) { + const float16_t v_half1 = 0.5f; + const float16_t v_one1 = 1.0f; + const float16_t v_sqrt1_21 = static_cast(M_SQRT1_2); + const float16_t v_B1 = 0.7978845608028654f; + const float16_t v_C1 = 0.035677408136300125f; + const float16_t c1 = 5.0f; + const float16_t c2 = -5.0f; + const MLAS_FLOAT16X8 v_half = MlasBroadcastF16Float16x8(v_half1); + const MLAS_FLOAT16X8 v_one = MlasBroadcastF16Float16x8(v_one1); + const MLAS_FLOAT16X8 v_sqrt1_2 = MlasBroadcastF16Float16x8(v_sqrt1_21); + const MLAS_FLOAT16X8 v_B = MlasBroadcastF16Float16x8(v_B1); + const MLAS_FLOAT16X8 v_C = MlasBroadcastF16Float16x8(v_C1); + + int64_t i = 0; + + if (algo == "tanh") { + // Preprocess input into temp[] for tanh + for (; i + 7 < count; i += 8) { + MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast(input + i)); + MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16(x, x); + MLAS_FLOAT16X8 inner = MlasMultiplyAddFloat16(v_C, x2, v_B); // B + C * x^2 + MLAS_FLOAT16X8 tanh_arg = MlasMultiplyFloat16(x, inner); // x * (B + C * x^2) + tanh_arg = MlasMaximumFloat16(MlasBroadcastF16Float16x8(c2), MlasMinimumFloat16(tanh_arg, MlasBroadcastF16Float16x8(c1))); + MlasStoref16Float16x8(reinterpret_cast(temp + i), tanh_arg); + } + + // Tail + for (; i < count; ++i) { + float x = static_cast(input[i]); + float inner = x * (0.7979f + 0.03568f * x * x); + inner = std::max(-5.0f, std::min(5.0f, inner)); + temp[i] = static_cast(inner); + } + + // Tanh processing + MlasComputeTanh(reinterpret_cast(temp), + reinterpret_cast(temp), + count); + + } else if (algo == "none") { + // Preprocess input into temp[] for erf + for (i = 0; i + 7 < count; i += 8) { + MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast(input + i)); + MLAS_FLOAT16X8 scaled = MlasMultiplyFloat16(x, v_sqrt1_2); + MlasStoref16Float16x8(reinterpret_cast(temp + i), scaled); + } + + // Tail + for (; i < count; ++i) { + float x = static_cast(input[i]); + temp[i] = static_cast(x * 0.70710678f); + } + + // Erf processing + MlasNeonErfKernelFp16(reinterpret_cast(temp), + reinterpret_cast<_mlas_fp16_*>(temp), + count); + } + + // Final GELU output = 0.5 * x * (1 + tanh|erf) + i = 0; + for (; i + 7 < count; i += 8) { + MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast(input + i)); + MLAS_FLOAT16X8 t = MlasLoadf16Float16x8(reinterpret_cast(temp + i)); + MLAS_FLOAT16X8 result = MlasMultiplyFloat16(v_half, MlasMultiplyFloat16(x, MlasAddFloat16(v_one, t))); + MlasStoref16Float16x8(reinterpret_cast(output + i), result); + } + + for (; i < count; ++i) { + float x = static_cast(input[i]); + float t = static_cast(temp[i]); + float gelu = 0.5f * x * (1.0f + t); + output[i] = static_cast(gelu); + } +} + +#endif + +template <> +Status Gelu::Compute(OpKernelContext* 
context) const { + const Tensor* input = context->Input(0); + const MLFloat16* input_data = input->Data(); + Tensor* output = context->Output(0, input->Shape()); + MLFloat16* output_data = output->MutableData(); + concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); + int64_t elem_count = input->Shape().Size(); + constexpr int64_t length_per_task = 4096; + int64_t task_count = (elem_count + length_per_task - 1) / length_per_task; + + if (approximation_algorithm_ != "tanh" && approximation_algorithm_ != "none") { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported approximation_algorithm: ", approximation_algorithm_); + } + + // Alignment and buffer size for aligned_alloc + constexpr size_t alignment = 64; + size_t buffer_size = elem_count * sizeof(MLFloat16); + size_t aligned_size = ((buffer_size + alignment - 1) / alignment) * alignment; + auto deleter = [](MLFloat16* p) { std::free(p); }; + std::unique_ptr temp_fp16_aligned( + reinterpret_cast(std::aligned_alloc(alignment, aligned_size)), + deleter); + if (temp_fp16_aligned == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to allocate aligned temporary buffer."); + } + + concurrency::ThreadPool::TryBatchParallelFor( + tp, + static_cast(task_count), + [&](ptrdiff_t task_idx) { + const auto start = task_idx * length_per_task; + const MLFloat16* p_input = input_data + start; + MLFloat16* p_output = output_data + start; + int64_t count = std::min(length_per_task, elem_count - start); + +#if defined(MLAS_USE_SVE) || defined(MLAS_NEON_INTRINSICS) + MLFloat16* p_temp = temp_fp16_aligned.get() + start; + bool done = false; + +#if defined(MLAS_USE_SVE) + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) { + ComputeGeluFp16_SVE(p_input, p_output, p_temp, count, approximation_algorithm_); + done = true; + } +#endif + +#if defined(MLAS_NEON_INTRINSICS) + if (!done) { + ComputeGeluFp16_NEON(p_input, p_output, p_temp, count, approximation_algorithm_); + done = true; + } +#endif +#else + for (int64_t i = 0; i < count; ++i) { + float x = static_cast(p_input[i]); + float gelu_val; + if (approximation_algorithm_ == "tanh") { + // GELU approx with tanh + const float B = 0.7978845608f; + const float C = 0.044715f * B; + float tanh_arg = x * (B + C * x * x); + float tanh_res = std::tanh(tanh_arg); + gelu_val = 0.5f * x * (1 + tanh_res); + } else { // "none" + gelu_val = 0.5f * x * (1 + std::erf(x * static_cast(M_SQRT1_2))); + } + p_output[i] = MLFloat16(gelu_val); + } +#endif + }, + 0); + return Status::OK(); +} + } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc index d711e050fb913..b0a4ce3ed6599 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc @@ -752,6 +752,58 @@ TEST_F(ActivationOpTest, ONNX_Gelu) { {}, {{"approximate", "tanh"}}, true, 20); } + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) +TEST_F(ActivationOpTest, Gelu_fp16_tanh) { + OpTester test("Gelu", 20); + auto formula = [](float x) { + return 0.5f * x * (1 + tanhf(0.7978845608028654f * (x + 0.044715f * x * x * x))); + }; + const std::vector X = {-1.0f, 0, 1.0f, 100.0f, -100.0f, 1000.0f, -1000.0f}; + std::vector Y; + Y.reserve(X.size()); + for (float x : X) { + Y.push_back(formula(x)); + } + std::vector dims{static_cast(X.size())}; + + std::vector f_X(X.size()); + std::vector f_Y(Y.size()); + ConvertFloatToMLFloat16(X.data(), f_X.data(), 
static_cast(X.size())); + ConvertFloatToMLFloat16(Y.data(), f_Y.data(), static_cast(Y.size())); + + test.AddInput("X", dims, f_X); + test.AddOutput("Y", dims, f_Y); + test.AddAttribute("approximate", "tanh"); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +TEST_F(ActivationOpTest, Gelu_fp16_erf) { + OpTester test("Gelu", 20); + auto formula = [](float x) { + return static_cast(0.5 * x * (1 + erf(x * M_SQRT1_2))); + }; + const std::vector X = {-1.0f, 0, 1.0f, 100.0f, -100.0f, 1000.0f, -1000.0f}; + std::vector Y; + Y.reserve(X.size()); + for (float x : X) { + Y.push_back(formula(x)); + } + std::vector dims{static_cast(X.size())}; + + std::vector f_X(X.size()); + std::vector f_Y(Y.size()); + ConvertFloatToMLFloat16(X.data(), f_X.data(), static_cast(X.size())); + ConvertFloatToMLFloat16(Y.data(), f_Y.data(), static_cast(Y.size())); + + test.AddInput("X", dims, f_X); + test.AddOutput("Y", dims, f_Y); + test.AddAttribute("approximate", "none"); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} +#endif #endif } // namespace test From 937cd8263a43dc48f1fd1aeb6fe3e6e6c9fe305a Mon Sep 17 00:00:00 2001 From: akote123 Date: Mon, 22 Dec 2025 21:27:19 +0530 Subject: [PATCH 2/7] Resolve Review Comments --- cmake/onnxruntime_mlas.cmake | 9 +- cmake/onnxruntime_providers_cpu.cmake | 9 +- onnxruntime/core/mlas/inc/mlas.h | 29 +++++ onnxruntime/core/mlas/lib/erf.cpp | 49 +++++++ onnxruntime/core/mlas/lib/erf_neon_fp16.cpp | 4 +- onnxruntime/core/mlas/lib/erf_neon_fp16.h | 2 +- onnxruntime/core/mlas/lib/gelu.cpp | 68 ++++++++++ onnxruntime/core/mlas/lib/gelu.h | 32 +++++ onnxruntime/core/mlas/lib/gelu_neon_fp16.cpp | 93 +++++++++++++ ..._sve_fp16.cpp => elementwise_sve_fp16.cpp} | 21 ++- onnxruntime/core/mlas/lib/sve/mlasi_sve.h | 28 +++- onnxruntime/core/mlas/lib/tanh.cpp | 2 +- .../providers/cpu/math/element_wise_ops.cc | 34 +---- onnxruntime/core/providers/cpu/tensor/gelu.cc | 123 +----------------- 14 files changed, 318 insertions(+), 185 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/gelu.cpp create mode 100644 onnxruntime/core/mlas/lib/gelu.h create mode 100644 onnxruntime/core/mlas/lib/gelu_neon_fp16.cpp rename onnxruntime/core/mlas/lib/sve/{Elementwise_sve_fp16.cpp => elementwise_sve_fp16.cpp} (93%) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index d0efa2abe4a6e..35033343a173e 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -54,6 +54,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/rotary_embedding.cpp ${MLAS_SRC_DIR}/softmax.h ${MLAS_SRC_DIR}/saturation_check.cpp + ${MLAS_SRC_DIR}/gelu.cpp ) target_sources(onnxruntime_mlas PRIVATE @@ -118,6 +119,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/sconv_nchw_kernel_neon.cpp ${MLAS_SRC_DIR}/erf_neon_fp16.h ${MLAS_SRC_DIR}/erf_neon_fp16.cpp + ${MLAS_SRC_DIR}/gelu_neon_fp16.cpp ) set(mlas_platform_preprocess_srcs @@ -483,15 +485,16 @@ else() ${MLAS_SRC_DIR}/sconv_nchw_kernel_neon.cpp ${MLAS_SRC_DIR}/erf_neon_fp16.h ${MLAS_SRC_DIR}/erf_neon_fp16.cpp + ${MLAS_SRC_DIR}/gelu_neon_fp16.cpp ) # Conditionally add the SVE implementation if compiler supports it if (onnxruntime_USE_SVE) list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/mlasi_sve.h) list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/elementwise_sve.cpp) - list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/Elementwise_sve_fp16.cpp) + list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/elementwise_sve_fp16.cpp) 
set_source_files_properties(${MLAS_SRC_DIR}/sve/elementwise_sve.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/sve/Elementwise_sve_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/sve/elementwise_sve_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 ") list(APPEND mlas_private_compile_definitions MLAS_USE_SVE) endif() @@ -529,6 +532,7 @@ else() ${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp ${MLAS_SRC_DIR}/erf_neon_fp16.cpp + ${MLAS_SRC_DIR}/gelu_neon_fp16.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") @@ -546,6 +550,7 @@ else() set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/erf_neon_fp16.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/gelu_neon_fp16.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+fp16 ") endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) diff --git a/cmake/onnxruntime_providers_cpu.cmake b/cmake/onnxruntime_providers_cpu.cmake index 79feaaec00b75..f77a5dd78fcc5 100644 --- a/cmake/onnxruntime_providers_cpu.cmake +++ b/cmake/onnxruntime_providers_cpu.cmake @@ -182,14 +182,7 @@ if (onnxruntime_ENABLE_CPU_FP16_OPS) set_source_files_properties(${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/adasum_kernels.cc PROPERTIES COMPILE_FLAGS " -fassociative-math -ffast-math -ftree-vectorize -funsafe-math-optimizations -mf16c -mavx -mfma ") endif() -if(onnxruntime_target_platform STREQUAL "aarch64" OR onnxruntime_target_platform STREQUAL "ARM64" OR onnxruntime_target_platform STREQUAL "arm64") -set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/tensor/gelu.cc" PROPERTIES COMPILE_FLAGS -march=armv8.2-a+fp16) -endif() -target_include_directories(onnxruntime_providers PRIVATE - ${ONNXRUNTIME_ROOT} - ${ONNXRUNTIME_ROOT}/core/mlas/inc -) - +target_include_directories(onnxruntime_providers PRIVATE ${ONNXRUNTIME_ROOT}) onnxruntime_add_include_to_target(onnxruntime_providers re2::re2 Eigen3::Eigen) add_dependencies(onnxruntime_providers onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index db7ec288001f9..f9370dc1fa380 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -2127,3 +2127,32 @@ MlasFlashAttention( MlasFlashAttentionThreadedArgs* args, MLAS_THREADPOOL* ThreadPool ); + +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) +/** + * @brief Function to override the packing mechanism decision if kleidi ai is included + * @param enable enable kleidiai packing (allow or disallow depending on true/false) + * @return +*/ +void +MLASCALL +MlasGemmBatchPackUseKleidi(bool enable); +#endif + +void +MLASCALL +MlasComputeFP16Erf( + const MLAS_FP16* Input, + MLAS_FP16* Output, + size_t N +); + +void +MLASCALL +MlasComputeFP16Gelu( + const MLAS_FP16* input, + MLAS_FP16* output, + MLAS_FP16* temp, + int64_t count, + const std::string& algo +); diff --git a/onnxruntime/core/mlas/lib/erf.cpp 
b/onnxruntime/core/mlas/lib/erf.cpp index f9724062e1f4d..e150739215232 100644 --- a/onnxruntime/core/mlas/lib/erf.cpp +++ b/onnxruntime/core/mlas/lib/erf.cpp @@ -22,6 +22,15 @@ Module Name: --*/ #include "mlasi.h" + +#ifdef MLAS_USE_SVE +#include "sve/mlasi_sve.h" +#endif + +#if defined(MLAS_NEON_INTRINSICS) +#include "erf_neon_fp16.h" +#endif + // // Bundles the constants for use by kernels written in assembly. // @@ -266,3 +275,43 @@ Return Value: MlasErfKernel(Input, Output, N); #endif } + +void +MLASCALL +MlasComputeFP16Erf( + const MLAS_FP16* Input, + MLAS_FP16* Output, + size_t N + ) +{ +#if defined(MLAS_USE_SVE) || defined(MLAS_NEON_INTRINSICS) + +#if defined(MLAS_USE_SVE) + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) { + MlasSveErfF16Kernel( + reinterpret_cast(Input), + reinterpret_cast<_mlas_fp16_*>(Output), + N + ); + return; + } +#endif + +#if defined(MLAS_NEON_INTRINSICS) + MlasNeonErfF16Kernel( + reinterpret_cast(Input), + reinterpret_cast<_mlas_fp16_*>(Output), + N + ); + return; +#endif + +#else + std::vector input_fp32(N); + std::vector output_fp32(N); + + MlasConvertHalfToFloatBuffer(Input, input_fp32.data(), N); + MlasComputeErf(input_fp32.data(), output_fp32.data(), N); + MlasConvertFloatToHalfBuffer(output_fp32.data(), Output, N); +#endif +} diff --git a/onnxruntime/core/mlas/lib/erf_neon_fp16.cpp b/onnxruntime/core/mlas/lib/erf_neon_fp16.cpp index 60c7695346a54..ba0068f92eadd 100644 --- a/onnxruntime/core/mlas/lib/erf_neon_fp16.cpp +++ b/onnxruntime/core/mlas/lib/erf_neon_fp16.cpp @@ -67,7 +67,7 @@ exp_neg_rational_approx_f16(MLAS_FLOAT16X8 x) } void -MlasNeonErfKernelFp16(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N) +MlasNeonErfF16Kernel(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N) { const float16_t p = 0.328f; const float16_t a1 = 0.2505f; @@ -144,4 +144,4 @@ MlasNeonErfKernelFp16(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N) Output[i] = float_to_fp16(erf_approx); } -} \ No newline at end of file +} diff --git a/onnxruntime/core/mlas/lib/erf_neon_fp16.h b/onnxruntime/core/mlas/lib/erf_neon_fp16.h index 154ef0104baa3..c8fd77e0e62d1 100644 --- a/onnxruntime/core/mlas/lib/erf_neon_fp16.h +++ b/onnxruntime/core/mlas/lib/erf_neon_fp16.h @@ -21,4 +21,4 @@ Module Name: #include "softmax_kernel_neon.h" using _mlas_fp16_ = uint16_t; -void MlasNeonErfKernelFp16(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N); +void MlasNeonErfF16Kernel(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N); diff --git a/onnxruntime/core/mlas/lib/gelu.cpp b/onnxruntime/core/mlas/lib/gelu.cpp new file mode 100644 index 0000000000000..a9e454925c4c6 --- /dev/null +++ b/onnxruntime/core/mlas/lib/gelu.cpp @@ -0,0 +1,68 @@ +/*++ + +Copyright 2025 FUJITSU LIMITED + +Module Name: + + Gelu.cpp + +Abstract: + + This module contains Gelu helper functions . 
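+
+  MlasComputeFP16Gelu dispatches to the SVE kernel when the CPU reports SVE
+  support, otherwise to the NEON FP16 kernel when that is available, and
+  otherwise falls back to a scalar float implementation of both the "tanh"
+  approximation and the exact ("none") formulation.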
+ +--*/ + +#include "gelu.h" + + +void +MLASCALL +MlasComputeFP16Gelu(const MLAS_FP16* input, + MLAS_FP16* output, + MLAS_FP16* temp, + int64_t count, + const std::string& algo) +{ +#if defined(MLAS_USE_SVE) || defined(MLAS_NEON_INTRINSICS) + + bool done = false; + +#if defined(MLAS_USE_SVE) + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) { + MlasSveGeluF16Kernel(input, output, temp, count, algo); + done = true; + } +#endif + +#if defined(MLAS_NEON_INTRINSICS) + if (!done) { + MlasNeonGeluF16Kernel(input, output, temp, count, algo); + done = true; + } +#endif + +#else + + (void)temp; + for (int64_t i = 0; i < count; ++i) { + float x = static_cast(input[i]); + float gelu_val; + + if (algo == "tanh") { + // GELU approximation (tanh) + const float B = 0.7978845608f; + const float C = 0.044715f * B; + float tanh_arg = x * (B + C * x * x); + float tanh_res = std::tanh(tanh_arg); + gelu_val = 0.5f * x * (1.0f + tanh_res); + } else { + // GELU exact (erf) + gelu_val = 0.5f * x * + (1.0f + std::erf(x * static_cast(M_SQRT1_2))); + } + + output[i] = MLAS_FP16(gelu_val); + } + +#endif +} diff --git a/onnxruntime/core/mlas/lib/gelu.h b/onnxruntime/core/mlas/lib/gelu.h new file mode 100644 index 0000000000000..727d15302d509 --- /dev/null +++ b/onnxruntime/core/mlas/lib/gelu.h @@ -0,0 +1,32 @@ +/*++ + +Copyright 2025 FUJITSU LIMITED + +Module Name: + + Gelu.cpp + +Abstract: + + This module contains Gelu helper functions . + +--*/ + +#include "fp16_common.h" +#if defined(MLAS_NEON_INTRINSICS) +#include "erf_neon_fp16.h" +#endif + +#ifdef MLAS_USE_SVE +#include "sve/mlasi_sve.h" +#endif + +void +MLASCALL +MlasNeonGeluF16Kernel( + const MLAS_FP16* input, + MLAS_FP16* output, + MLAS_FP16* temp, + int64_t count, + const std::string& algo +); \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/gelu_neon_fp16.cpp b/onnxruntime/core/mlas/lib/gelu_neon_fp16.cpp new file mode 100644 index 0000000000000..c1940cc2eb728 --- /dev/null +++ b/onnxruntime/core/mlas/lib/gelu_neon_fp16.cpp @@ -0,0 +1,93 @@ +/*++ + +Copyright 2025 FUJITSU LIMITED + +Module Name: + + Gelu.cpp + +Abstract: + + This module contains Gelu helper functions . 
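+
+  The NEON kernel works in three passes over each chunk: it first writes the
+  tanh or erf argument into the caller-provided temp buffer, then evaluates
+  tanh (MlasComputeTanh) or erf (MlasNeonErfF16Kernel) in place, and finally
+  forms 0.5 * x * (1 + tanh|erf) into the output, with scalar tails covering
+  the remaining (count % 8) elements.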
+ +--*/ +#include "gelu.h" + +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + +void +MLASCALL +MlasNeonGeluF16Kernel(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp, int64_t count, const std::string& algo) +{ + const float16_t v_half1 = 0.5f; + const float16_t v_one1 = 1.0f; + const float16_t v_sqrt1_21 = static_cast(M_SQRT1_2); + const float16_t v_B1 = 0.7978845608028654f; + const float16_t v_C1 = 0.035677408136300125f; + const float16_t c1 = 5.0f; + const float16_t c2 = -5.0f; + const MLAS_FLOAT16X8 v_half = MlasBroadcastF16Float16x8(v_half1); + const MLAS_FLOAT16X8 v_one = MlasBroadcastF16Float16x8(v_one1); + const MLAS_FLOAT16X8 v_sqrt1_2 = MlasBroadcastF16Float16x8(v_sqrt1_21); + const MLAS_FLOAT16X8 v_B = MlasBroadcastF16Float16x8(v_B1); + const MLAS_FLOAT16X8 v_C = MlasBroadcastF16Float16x8(v_C1); + + int64_t i = 0; + + if (algo == "tanh") { + // Preprocess input into temp[] for tanh + for (; i + 7 < count; i += 8) { + MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast(input + i)); + MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16(x, x); + MLAS_FLOAT16X8 inner = MlasMultiplyAddFloat16(v_C, x2, v_B); // B + C * x^2 + MLAS_FLOAT16X8 tanh_arg = MlasMultiplyFloat16(x, inner); // x * (B + C * x^2) + tanh_arg = MlasMaximumFloat16(MlasBroadcastF16Float16x8(c2), MlasMinimumFloat16(tanh_arg, MlasBroadcastF16Float16x8(c1))); + MlasStoref16Float16x8(reinterpret_cast(temp + i), tanh_arg); + } + + // Tail + for (; i < count; ++i) { + float x = static_cast(input[i]); + float inner = x * (0.7979f + 0.03568f * x * x); + inner = std::max(-5.0f, std::min(5.0f, inner)); + temp[i] = static_cast(inner); + } + + // Tanh processing + MlasComputeTanh(temp, temp, count); + + } else if (algo == "none") { + // Preprocess input into temp[] for erf + for (i = 0; i + 7 < count; i += 8) { + MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast(input + i)); + MLAS_FLOAT16X8 scaled = MlasMultiplyFloat16(x, v_sqrt1_2); + MlasStoref16Float16x8(reinterpret_cast(temp + i), scaled); + } + + // Tail + for (; i < count; ++i) { + float x = static_cast(input[i]); + temp[i] = static_cast(x * 0.70710678f); + } + + // Erf processing + MlasNeonErfF16Kernel(reinterpret_cast(temp), reinterpret_cast<_mlas_fp16_*>(temp), count); + } + + // Final GELU output = 0.5 * x * (1 + tanh|erf) + i = 0; + for (; i + 7 < count; i += 8) { + MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast(input + i)); + MLAS_FLOAT16X8 t = MlasLoadf16Float16x8(reinterpret_cast(temp + i)); + MLAS_FLOAT16X8 result = MlasMultiplyFloat16(v_half, MlasMultiplyFloat16(x, MlasAddFloat16(v_one, t))); + MlasStoref16Float16x8(reinterpret_cast(output + i), result); + } + + for (; i < count; ++i) { + float x = static_cast(input[i]); + float t = static_cast(temp[i]); + float gelu = 0.5f * x * (1.0f + t); + output[i] = static_cast(gelu); + } +} +#endif \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/sve/Elementwise_sve_fp16.cpp b/onnxruntime/core/mlas/lib/sve/elementwise_sve_fp16.cpp similarity index 93% rename from onnxruntime/core/mlas/lib/sve/Elementwise_sve_fp16.cpp rename to onnxruntime/core/mlas/lib/sve/elementwise_sve_fp16.cpp index 2530c712a8866..e6fadd9dbc9ed 100644 --- a/onnxruntime/core/mlas/lib/sve/Elementwise_sve_fp16.cpp +++ b/onnxruntime/core/mlas/lib/sve/elementwise_sve_fp16.cpp @@ -29,17 +29,16 @@ struct MlasTanhConstants_fp16_scalar { constexpr MlasTanhConstants_fp16_scalar TanhConstantsFp16 = { -3.515625f, 3.515625f, - 5.960464477539063e-08f, + 5.960464477539063e-08f, 1.4841556549072266e-05f, 
- 0.000637054443359375f, + 0.000637054443359375f, 0.004894256591796875f, - 1.1920928955078125e-06f, + 1.1920928955078125e-06f, 0.00011855363845825195f, 0.0022678375244140625f, - 0.004894256591796875f + 0.004894256591796875f }; - static inline MLAS_SVFLOAT16 Tanh_Vector_SVE_fp16(MLAS_SVFLOAT16 x, MLAS_SVBOOL pg) { @@ -53,7 +52,7 @@ Tanh_Vector_SVE_fp16(MLAS_SVFLOAT16 x, MLAS_SVBOOL pg) MLAS_SVFLOAT16 g_beta_4_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.beta_4); MLAS_SVFLOAT16 g_beta_2_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.beta_2); MLAS_SVFLOAT16 g_beta_0_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.beta_0); - + x = MlasSveMinfloat16(pg, x, g_UpperRange_vec); x = MlasSveMaxfloat16(pg, x, g_LowerRange_vec); @@ -73,7 +72,7 @@ Tanh_Vector_SVE_fp16(MLAS_SVFLOAT16 x, MLAS_SVBOOL pg) } void -MlasTanhKernelFp16_SVE(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N) +MlasSveTanhF16Kernel(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N) { size_t offset = 0; const auto* input = reinterpret_cast(Input); @@ -125,7 +124,7 @@ exp_neg_rational_approx_f16(MLAS_SVBOOL pg, MLAS_SVFLOAT16 x) } void MLASCALL -MlasSveErfKernelFp16(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N) +MlasSveErfF16Kernel(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N) { const __fp16 p = 0.328f; const __fp16 a1 = 0.2505f; @@ -186,7 +185,7 @@ MlasSveErfKernelFp16(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N) } void MLASCALL -ComputeGeluFp16_SVE(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp, int64_t count, const std::string& algo) +MlasSveGeluF16Kernel(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp, int64_t count, const std::string& algo) { const __fp16 r1 = 0.5f; const __fp16 r2 = 1.0f; @@ -215,7 +214,7 @@ ComputeGeluFp16_SVE(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp, i += svcnth(); } - MlasComputeTanh(reinterpret_cast(temp), reinterpret_cast(temp), count); + MlasSveTanhF16Kernel(reinterpret_cast(temp), reinterpret_cast(temp), count); int64_t j = 0; while (j < (count)) { @@ -236,7 +235,7 @@ ComputeGeluFp16_SVE(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp, i += svcnth(); } - MlasSveErfKernelFp16(reinterpret_cast(temp), reinterpret_cast<_mlas_fp16_*>(temp), count); + MlasSveErfF16Kernel(reinterpret_cast(temp), reinterpret_cast<_mlas_fp16_*>(temp), count); int64_t j = 0; while (j < (count)) { diff --git a/onnxruntime/core/mlas/lib/sve/mlasi_sve.h b/onnxruntime/core/mlas/lib/sve/mlasi_sve.h index a0ff9c14a0159..ec53172a83325 100644 --- a/onnxruntime/core/mlas/lib/sve/mlasi_sve.h +++ b/onnxruntime/core/mlas/lib/sve/mlasi_sve.h @@ -39,11 +39,29 @@ using _mlas_fp16_ = uint16_t; void MLASCALL -MlasSveErfKernelFp16( - const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N - ); -void MLASCALL MlasTanhKernelFp16_SVE(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N); -void MLASCALL ComputeGeluFp16_SVE(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp, int64_t count, const std::string& algo); +MlasSveErfF16Kernel( + const _mlas_fp16_* Input, + _mlas_fp16_* Output, + size_t N +); + +void +MLASCALL +MlasSveTanhF16Kernel( + const MLAS_FP16* Input, + MLAS_FP16* Output, + size_t N +); + +void +MLASCALL +MlasSveGeluF16Kernel( + const MLAS_FP16* input, + MLAS_FP16* output, + MLAS_FP16* temp, + int64_t count, + const std::string& algo +); // function decarations MLAS_FORCEINLINE MLAS_SVFLOAT32 diff --git a/onnxruntime/core/mlas/lib/tanh.cpp b/onnxruntime/core/mlas/lib/tanh.cpp index fcd1492e19ee9..7fb0336e527ae 100644 --- 
a/onnxruntime/core/mlas/lib/tanh.cpp +++ b/onnxruntime/core/mlas/lib/tanh.cpp @@ -197,7 +197,7 @@ MlasComputeTanh( ) { #if defined(MLAS_USE_SVE) if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) { - MlasTanhKernelFp16_SVE(Input, Output, N); + MlasSveTanhF16Kernel(Input, Output, N); return; } #endif diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 12eebc1339751..8e356b14fe872 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -11,10 +11,6 @@ #include "core/util/math.h" #include "core/mlas/inc/mlas.h" -#if defined(MLAS_NEON_INTRINSICS) -#include "core/mlas/lib/erf_neon_fp16.h" -#endif - #include namespace onnxruntime { @@ -2042,37 +2038,9 @@ Status Erf::Compute(OpKernelContext* context) const { [&](ptrdiff_t task_idx) { const auto start = task_idx * length_per_task; const int64_t count = std::min(length_per_task, elem_count - start); - const auto narrow_count = onnxruntime::narrow(count); - const MLFloat16* p_input = input_data + start; MLFloat16* p_output = output_data + start; - -#if defined(MLAS_USE_SVE) || defined(MLAS_NEON_INTRINSICS) -#if defined(MLAS_USE_SVE) - if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) { - MlasSveErfKernelFp16( - reinterpret_cast(p_input), - reinterpret_cast<_mlas_fp16_*>(p_output), - narrow_count); - return; - } -#endif -#if defined(MLAS_NEON_INTRINSICS) - MlasNeonErfKernelFp16( - reinterpret_cast(p_input), - reinterpret_cast<_mlas_fp16_*>(p_output), - narrow_count); - return; -#endif -#else - // Fallback: convert half to float, compute erf, convert back - IAllocatorUniquePtr input_fp32 = IAllocator::MakeUniquePtr(alloc, narrow_count); - IAllocatorUniquePtr output_fp32 = IAllocator::MakeUniquePtr(alloc, narrow_count); - - MlasConvertHalfToFloatBuffer(p_input, input_fp32.get(), narrow_count); - MlasComputeErf(input_fp32.get(), output_fp32.get(), narrow_count); - MlasConvertFloatToHalfBuffer(output_fp32.get(), p_output, narrow_count); -#endif + MlasComputeFP16Erf(p_input, p_output, count); }, 0); diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.cc b/onnxruntime/core/providers/cpu/tensor/gelu.cc index f7bf01c609ae1..078149346dc72 100644 --- a/onnxruntime/core/providers/cpu/tensor/gelu.cc +++ b/onnxruntime/core/providers/cpu/tensor/gelu.cc @@ -11,14 +11,6 @@ #include #include "core/providers/cpu/element_wise_ranged_transform.h" #include "core/providers/cpu/tensor/gelu.h" -#include "core/mlas/lib/fp16_common.h" -#if defined(MLAS_NEON_INTRINSICS) -#include "core/mlas/lib/erf_neon_fp16.h" -#endif - -#ifdef MLAS_USE_SVE -#include "core/mlas/lib/sve/mlasi_sve.h" -#endif using onnxruntime::narrow; using namespace onnxruntime::common; @@ -119,87 +111,6 @@ Status Gelu::Compute(OpKernelContext* context) const { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported approximation_algorithm: ", approximation_algorithm_); } -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) - -void ComputeGeluFp16_NEON(const MLFloat16* input, MLFloat16* output, MLFloat16* temp, int64_t count, const std::string& algo) { - const float16_t v_half1 = 0.5f; - const float16_t v_one1 = 1.0f; - const float16_t v_sqrt1_21 = static_cast(M_SQRT1_2); - const float16_t v_B1 = 0.7978845608028654f; - const float16_t v_C1 = 0.035677408136300125f; - const float16_t c1 = 5.0f; - const float16_t c2 = -5.0f; - const MLAS_FLOAT16X8 v_half = MlasBroadcastF16Float16x8(v_half1); - const MLAS_FLOAT16X8 v_one = 
MlasBroadcastF16Float16x8(v_one1); - const MLAS_FLOAT16X8 v_sqrt1_2 = MlasBroadcastF16Float16x8(v_sqrt1_21); - const MLAS_FLOAT16X8 v_B = MlasBroadcastF16Float16x8(v_B1); - const MLAS_FLOAT16X8 v_C = MlasBroadcastF16Float16x8(v_C1); - - int64_t i = 0; - - if (algo == "tanh") { - // Preprocess input into temp[] for tanh - for (; i + 7 < count; i += 8) { - MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast(input + i)); - MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16(x, x); - MLAS_FLOAT16X8 inner = MlasMultiplyAddFloat16(v_C, x2, v_B); // B + C * x^2 - MLAS_FLOAT16X8 tanh_arg = MlasMultiplyFloat16(x, inner); // x * (B + C * x^2) - tanh_arg = MlasMaximumFloat16(MlasBroadcastF16Float16x8(c2), MlasMinimumFloat16(tanh_arg, MlasBroadcastF16Float16x8(c1))); - MlasStoref16Float16x8(reinterpret_cast(temp + i), tanh_arg); - } - - // Tail - for (; i < count; ++i) { - float x = static_cast(input[i]); - float inner = x * (0.7979f + 0.03568f * x * x); - inner = std::max(-5.0f, std::min(5.0f, inner)); - temp[i] = static_cast(inner); - } - - // Tanh processing - MlasComputeTanh(reinterpret_cast(temp), - reinterpret_cast(temp), - count); - - } else if (algo == "none") { - // Preprocess input into temp[] for erf - for (i = 0; i + 7 < count; i += 8) { - MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast(input + i)); - MLAS_FLOAT16X8 scaled = MlasMultiplyFloat16(x, v_sqrt1_2); - MlasStoref16Float16x8(reinterpret_cast(temp + i), scaled); - } - - // Tail - for (; i < count; ++i) { - float x = static_cast(input[i]); - temp[i] = static_cast(x * 0.70710678f); - } - - // Erf processing - MlasNeonErfKernelFp16(reinterpret_cast(temp), - reinterpret_cast<_mlas_fp16_*>(temp), - count); - } - - // Final GELU output = 0.5 * x * (1 + tanh|erf) - i = 0; - for (; i + 7 < count; i += 8) { - MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast(input + i)); - MLAS_FLOAT16X8 t = MlasLoadf16Float16x8(reinterpret_cast(temp + i)); - MLAS_FLOAT16X8 result = MlasMultiplyFloat16(v_half, MlasMultiplyFloat16(x, MlasAddFloat16(v_one, t))); - MlasStoref16Float16x8(reinterpret_cast(output + i), result); - } - - for (; i < count; ++i) { - float x = static_cast(input[i]); - float t = static_cast(temp[i]); - float gelu = 0.5f * x * (1.0f + t); - output[i] = static_cast(gelu); - } -} - -#endif - template <> Status Gelu::Compute(OpKernelContext* context) const { const Tensor* input = context->Input(0); @@ -235,41 +146,9 @@ Status Gelu::Compute(OpKernelContext* context) const { const MLFloat16* p_input = input_data + start; MLFloat16* p_output = output_data + start; int64_t count = std::min(length_per_task, elem_count - start); - -#if defined(MLAS_USE_SVE) || defined(MLAS_NEON_INTRINSICS) MLFloat16* p_temp = temp_fp16_aligned.get() + start; - bool done = false; - -#if defined(MLAS_USE_SVE) - if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) { - ComputeGeluFp16_SVE(p_input, p_output, p_temp, count, approximation_algorithm_); - done = true; - } -#endif + MlasComputeFP16Gelu(p_input, p_output, p_temp, count, approximation_algorithm_); -#if defined(MLAS_NEON_INTRINSICS) - if (!done) { - ComputeGeluFp16_NEON(p_input, p_output, p_temp, count, approximation_algorithm_); - done = true; - } -#endif -#else - for (int64_t i = 0; i < count; ++i) { - float x = static_cast(p_input[i]); - float gelu_val; - if (approximation_algorithm_ == "tanh") { - // GELU approx with tanh - const float B = 0.7978845608f; - const float C = 0.044715f * B; - float tanh_arg = x * (B + C * x * x); - float tanh_res = std::tanh(tanh_arg); - gelu_val = 0.5f * x * (1 + 
tanh_res); - } else { // "none" - gelu_val = 0.5f * x * (1 + std::erf(x * static_cast(M_SQRT1_2))); - } - p_output[i] = MLFloat16(gelu_val); - } -#endif }, 0); return Status::OK(); From cf6d83f38637206c352c7539fdbc3df3b8d5add9 Mon Sep 17 00:00:00 2001 From: Sanket Kale Date: Thu, 22 Jan 2026 13:37:33 +0530 Subject: [PATCH 3/7] Resolved Copilot comments --- cmake/onnxruntime_mlas.cmake | 4 +- onnxruntime/core/mlas/lib/erf_neon_fp16.cpp | 8 +-- onnxruntime/core/mlas/lib/fp16_common.h | 2 +- onnxruntime/core/mlas/lib/gelu.cpp | 21 +------- onnxruntime/core/mlas/lib/gelu.h | 17 ++++--- onnxruntime/core/mlas/lib/gelu_neon_fp16.cpp | 10 ++-- .../mlas/lib/sve/elementwise_sve_fp16.cpp | 8 +-- onnxruntime/core/mlas/lib/sve/mlasi_sve.h | 12 ++--- onnxruntime/core/providers/cpu/tensor/gelu.cc | 49 ++++++++++++++++--- 9 files changed, 74 insertions(+), 57 deletions(-) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 35033343a173e..3cae38e5bae05 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -549,8 +549,8 @@ else() set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/erf_neon_fp16.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/gelu_neon_fp16.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/erf_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/gelu_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) diff --git a/onnxruntime/core/mlas/lib/erf_neon_fp16.cpp b/onnxruntime/core/mlas/lib/erf_neon_fp16.cpp index ba0068f92eadd..450aad217fa49 100644 --- a/onnxruntime/core/mlas/lib/erf_neon_fp16.cpp +++ b/onnxruntime/core/mlas/lib/erf_neon_fp16.cpp @@ -60,8 +60,8 @@ exp_neg_rational_approx_f16(MLAS_FLOAT16X8 x) MLAS_FLOAT16X8 den = MlasMultiplyAddFloat16(d1v, x, d0v); den = MlasMultiplyAddFloat16(d2v, x2, den); MLAS_FLOAT16X8 recip = MlasApproximateReciprocalFloat16(den); - recip = MlasMultiplyFloat16(recip, MlasReciprocalSqrtFloat16(den, recip)); - recip = MlasMultiplyFloat16(recip, MlasReciprocalSqrtFloat16(den, recip)); + recip = MlasMultiplyFloat16(recip, MlasReciprocalStepFloat16(den, recip)); + recip = MlasMultiplyFloat16(recip, MlasReciprocalStepFloat16(den, recip)); MLAS_FLOAT16X8 result = MlasMultiplyFloat16(num, recip); return result; } @@ -103,8 +103,8 @@ MlasNeonErfF16Kernel(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N) MLAS_FLOAT16X8 absx_clamped = MlasMinimumFloat16(absx, vth); MLAS_FLOAT16X8 denom = MlasMultiplyAddFloat16(vp, absx_clamped, vone); MLAS_FLOAT16X8 t = MlasApproximateReciprocalFloat16(denom); - t = MlasMultiplyFloat16(t, MlasReciprocalSqrtFloat16(denom, t)); - t = MlasMultiplyFloat16(t, MlasReciprocalSqrtFloat16(denom, t)); + t = MlasMultiplyFloat16(t, MlasReciprocalStepFloat16(denom, t)); + t = MlasMultiplyFloat16(t, MlasReciprocalStepFloat16(denom, t)); MLAS_FLOAT16X8 t2 = MlasMultiplyFloat16(t, t); MLAS_FLOAT16X8 t3 = MlasMultiplyFloat16(t2, t); MLAS_FLOAT16X8 t4 = MlasMultiplyFloat16(t3, t); diff --git 
a/onnxruntime/core/mlas/lib/fp16_common.h b/onnxruntime/core/mlas/lib/fp16_common.h index 71655f9b59905..31ad706fa26e2 100644 --- a/onnxruntime/core/mlas/lib/fp16_common.h +++ b/onnxruntime/core/mlas/lib/fp16_common.h @@ -596,7 +596,7 @@ MlasShiftLeftInt16(MLAS_INT16X4 Vector) MLAS_FORCEINLINE MLAS_FLOAT16X8 -MlasReciprocalSqrtFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +MlasReciprocalStepFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) { return vrecpsq_f16(Vector1, Vector2); } diff --git a/onnxruntime/core/mlas/lib/gelu.cpp b/onnxruntime/core/mlas/lib/gelu.cpp index a9e454925c4c6..a4daec1ee5298 100644 --- a/onnxruntime/core/mlas/lib/gelu.cpp +++ b/onnxruntime/core/mlas/lib/gelu.cpp @@ -8,13 +8,12 @@ Module Name: Abstract: - This module contains Gelu helper functions . + This module contains Gelu helper functions. --*/ #include "gelu.h" - void MLASCALL MlasComputeFP16Gelu(const MLAS_FP16* input, @@ -23,26 +22,11 @@ MlasComputeFP16Gelu(const MLAS_FP16* input, int64_t count, const std::string& algo) { -#if defined(MLAS_USE_SVE) || defined(MLAS_NEON_INTRINSICS) - - bool done = false; - #if defined(MLAS_USE_SVE) - if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) { MlasSveGeluF16Kernel(input, output, temp, count, algo); - done = true; - } -#endif - -#if defined(MLAS_NEON_INTRINSICS) - if (!done) { +#elif defined(MLAS_NEON_INTRINSICS) MlasNeonGeluF16Kernel(input, output, temp, count, algo); - done = true; - } -#endif - #else - (void)temp; for (int64_t i = 0; i < count; ++i) { float x = static_cast(input[i]); @@ -63,6 +47,5 @@ MlasComputeFP16Gelu(const MLAS_FP16* input, output[i] = MLAS_FP16(gelu_val); } - #endif } diff --git a/onnxruntime/core/mlas/lib/gelu.h b/onnxruntime/core/mlas/lib/gelu.h index 727d15302d509..b3c96091d58a1 100644 --- a/onnxruntime/core/mlas/lib/gelu.h +++ b/onnxruntime/core/mlas/lib/gelu.h @@ -4,22 +4,17 @@ Copyright 2025 FUJITSU LIMITED Module Name: - Gelu.cpp + gelu.h Abstract: - This module contains Gelu helper functions . + This module contains Gelu helper functions . 
--*/ #include "fp16_common.h" #if defined(MLAS_NEON_INTRINSICS) #include "erf_neon_fp16.h" -#endif - -#ifdef MLAS_USE_SVE -#include "sve/mlasi_sve.h" -#endif void MLASCALL @@ -29,4 +24,10 @@ MlasNeonGeluF16Kernel( MLAS_FP16* temp, int64_t count, const std::string& algo -); \ No newline at end of file +); + +#endif + +#ifdef MLAS_USE_SVE +#include "sve/mlasi_sve.h" +#endif diff --git a/onnxruntime/core/mlas/lib/gelu_neon_fp16.cpp b/onnxruntime/core/mlas/lib/gelu_neon_fp16.cpp index c1940cc2eb728..8dc356a347baa 100644 --- a/onnxruntime/core/mlas/lib/gelu_neon_fp16.cpp +++ b/onnxruntime/core/mlas/lib/gelu_neon_fp16.cpp @@ -4,7 +4,7 @@ Copyright 2025 FUJITSU LIMITED Module Name: - Gelu.cpp + gelu_neon_fp16.cpp Abstract: @@ -12,7 +12,7 @@ Module Name: --*/ #include "gelu.h" - +#include #if defined(__ARM_NEON) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) void @@ -48,7 +48,7 @@ MlasNeonGeluF16Kernel(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp // Tail for (; i < count; ++i) { float x = static_cast(input[i]); - float inner = x * (0.7979f + 0.03568f * x * x); + float inner = x * (0.7978845608028654f + 0.035677408136300125f * x * x); inner = std::max(-5.0f, std::min(5.0f, inner)); temp[i] = static_cast(inner); } @@ -56,7 +56,7 @@ MlasNeonGeluF16Kernel(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp // Tanh processing MlasComputeTanh(temp, temp, count); - } else if (algo == "none") { + } else{ // Preprocess input into temp[] for erf for (i = 0; i + 7 < count; i += 8) { MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast(input + i)); @@ -67,7 +67,7 @@ MlasNeonGeluF16Kernel(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp // Tail for (; i < count; ++i) { float x = static_cast(input[i]); - temp[i] = static_cast(x * 0.70710678f); + temp[i] = static_cast(x * static_cast(M_SQRT1_2)); } // Erf processing diff --git a/onnxruntime/core/mlas/lib/sve/elementwise_sve_fp16.cpp b/onnxruntime/core/mlas/lib/sve/elementwise_sve_fp16.cpp index e6fadd9dbc9ed..573a72dc2117d 100644 --- a/onnxruntime/core/mlas/lib/sve/elementwise_sve_fp16.cpp +++ b/onnxruntime/core/mlas/lib/sve/elementwise_sve_fp16.cpp @@ -190,8 +190,8 @@ MlasSveGeluF16Kernel(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp, const __fp16 r1 = 0.5f; const __fp16 r2 = 1.0f; const __fp16 r3 = static_cast(M_SQRT1_2); - const __fp16 r4 = 0.7979f; - const __fp16 r5 = 0.03568f; + const __fp16 r4 = 0.7978845608028654f; + const __fp16 r5 = 0.035677408136300125f; const MLAS_SVFLOAT16 v_half = MlasSveBroadcastfloat16(r1); const MLAS_SVFLOAT16 v_one = MlasSveBroadcastfloat16(r2); @@ -203,7 +203,7 @@ MlasSveGeluF16Kernel(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp, const __fp16 c2 = 5.0f; if (algo == "tanh") { int64_t i = 0; - while (i < (count)) { + while (i < count) { svbool_t pg = MlasSveSelPredictefloat16(i, count); MLAS_SVFLOAT16 v_x = MlasSveLoadFloat16(pg, &input[i]); MLAS_SVFLOAT16 v_x2 = MlasSveMulfloat16(pg, v_x, v_x); @@ -225,7 +225,7 @@ MlasSveGeluF16Kernel(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp, MlasSveStoreF16(pg, &output[j], v_result); j += svcnth(); } - } else if (algo == "none") { + } else { int64_t i = 0; while (i < (count)) { svbool_t pg = MlasSveSelPredictefloat16(i, count); diff --git a/onnxruntime/core/mlas/lib/sve/mlasi_sve.h b/onnxruntime/core/mlas/lib/sve/mlasi_sve.h index ec53172a83325..010b18b4be9ab 100644 --- a/onnxruntime/core/mlas/lib/sve/mlasi_sve.h +++ b/onnxruntime/core/mlas/lib/sve/mlasi_sve.h @@ -56,13 +56,13 @@ MlasSveTanhF16Kernel( void 
MLASCALL MlasSveGeluF16Kernel( - const MLAS_FP16* input, - MLAS_FP16* output, - MLAS_FP16* temp, - int64_t count, - const std::string& algo + const MLAS_FP16* Input, + MLAS_FP16* Output, + MLAS_FP16* Temp, + int64_t N, + const std::string& Algo ); -// function decarations +// function declarations MLAS_FORCEINLINE MLAS_SVFLOAT32 MlasSveComputeExpVector( diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.cc b/onnxruntime/core/providers/cpu/tensor/gelu.cc index 078149346dc72..1c0d6db2c98a8 100644 --- a/onnxruntime/core/providers/cpu/tensor/gelu.cc +++ b/onnxruntime/core/providers/cpu/tensor/gelu.cc @@ -12,6 +12,31 @@ #include "core/providers/cpu/element_wise_ranged_transform.h" #include "core/providers/cpu/tensor/gelu.h" +#include +#include +#include + +#if defined(_WIN32) + #include +#endif + +inline void* AlignedAlloc(size_t alignment, size_t size) { +#if defined(_WIN32) + return _aligned_malloc(size, alignment); +#else + // std::aligned_alloc requires size to be a multiple of alignment + return std::aligned_alloc(alignment, size); +#endif +} + +inline void AlignedFree(void* p) { +#if defined(_WIN32) + _aligned_free(p); +#else + std::free(p); +#endif +} + using onnxruntime::narrow; using namespace onnxruntime::common; @@ -128,16 +153,24 @@ Status Gelu::Compute(OpKernelContext* context) const { // Alignment and buffer size for aligned_alloc constexpr size_t alignment = 64; + size_t buffer_size = elem_count * sizeof(MLFloat16); - size_t aligned_size = ((buffer_size + alignment - 1) / alignment) * alignment; - auto deleter = [](MLFloat16* p) { std::free(p); }; - std::unique_ptr temp_fp16_aligned( - reinterpret_cast(std::aligned_alloc(alignment, aligned_size)), - deleter); - if (temp_fp16_aligned == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to allocate aligned temporary buffer."); + size_t aligned_size = + ((buffer_size + alignment - 1) / alignment) * alignment; + + void* raw = AlignedAlloc(alignment, aligned_size); + if (!raw) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Failed to allocate aligned temporary buffer."); } +auto deleter = [](MLFloat16* p) { + AlignedFree(p); +}; + +std::unique_ptr temp_fp16_aligned( + static_cast(raw), deleter); + concurrency::ThreadPool::TryBatchParallelFor( tp, static_cast(task_count), @@ -147,7 +180,7 @@ Status Gelu::Compute(OpKernelContext* context) const { MLFloat16* p_output = output_data + start; int64_t count = std::min(length_per_task, elem_count - start); MLFloat16* p_temp = temp_fp16_aligned.get() + start; - MlasComputeFP16Gelu(p_input, p_output, p_temp, count, approximation_algorithm_); + MlasComputeFP16Gelu(p_input, p_output, p_temp, count, approximation_algorithm_); }, 0); From 9126b69f091fb48402cc27a2d809b77c795dccd2 Mon Sep 17 00:00:00 2001 From: Sanket Kale Date: Tue, 3 Feb 2026 14:16:04 +0530 Subject: [PATCH 4/7] Resolved CI failures --- onnxruntime/core/mlas/lib/erf.cpp | 7 ++++--- onnxruntime/core/mlas/lib/gelu.cpp | 4 ++-- onnxruntime/core/mlas/lib/gelu.h | 4 +++- onnxruntime/core/providers/cpu/math/element_wise_ops.cc | 2 +- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/mlas/lib/erf.cpp b/onnxruntime/core/mlas/lib/erf.cpp index e150739215232..293ccebbe34a8 100644 --- a/onnxruntime/core/mlas/lib/erf.cpp +++ b/onnxruntime/core/mlas/lib/erf.cpp @@ -27,7 +27,7 @@ Module Name: #include "sve/mlasi_sve.h" #endif -#if defined(MLAS_NEON_INTRINSICS) +#if defined(MLAS_NEON_INTRINSICS) && defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) #include "erf_neon_fp16.h" #endif @@ -286,7 +286,7 @@ 
MlasComputeFP16Erf( { #if defined(MLAS_USE_SVE) || defined(MLAS_NEON_INTRINSICS) -#if defined(MLAS_USE_SVE) +#if defined(MLAS_USE_SVE) && defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) { MlasSveErfF16Kernel( reinterpret_cast(Input), @@ -297,7 +297,7 @@ MlasComputeFP16Erf( } #endif -#if defined(MLAS_NEON_INTRINSICS) +#if defined(MLAS_NEON_INTRINSICS) && defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) MlasNeonErfF16Kernel( reinterpret_cast(Input), reinterpret_cast<_mlas_fp16_*>(Output), @@ -315,3 +315,4 @@ MlasComputeFP16Erf( MlasConvertFloatToHalfBuffer(output_fp32.data(), Output, N); #endif } + \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/gelu.cpp b/onnxruntime/core/mlas/lib/gelu.cpp index a4daec1ee5298..869671efd4630 100644 --- a/onnxruntime/core/mlas/lib/gelu.cpp +++ b/onnxruntime/core/mlas/lib/gelu.cpp @@ -22,9 +22,9 @@ MlasComputeFP16Gelu(const MLAS_FP16* input, int64_t count, const std::string& algo) { -#if defined(MLAS_USE_SVE) +#if defined(MLAS_USE_SVE) && defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) MlasSveGeluF16Kernel(input, output, temp, count, algo); -#elif defined(MLAS_NEON_INTRINSICS) +#elif defined(MLAS_NEON_INTRINSICS) && defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) MlasNeonGeluF16Kernel(input, output, temp, count, algo); #else (void)temp; diff --git a/onnxruntime/core/mlas/lib/gelu.h b/onnxruntime/core/mlas/lib/gelu.h index b3c96091d58a1..554afd5d41575 100644 --- a/onnxruntime/core/mlas/lib/gelu.h +++ b/onnxruntime/core/mlas/lib/gelu.h @@ -12,8 +12,10 @@ Module Name: --*/ +#pragma once + #include "fp16_common.h" -#if defined(MLAS_NEON_INTRINSICS) +#if defined(MLAS_NEON_INTRINSICS) && defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) #include "erf_neon_fp16.h" void diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 8e356b14fe872..9f35204a3ec28 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -2009,7 +2009,7 @@ Status Erf::Compute(OpKernelContext* context) const { const float* p_input = input_data + start; float* p_output = output_data + start; const std::ptrdiff_t count = std::min(length_per_task, elem_count - start); - MlasComputeErf(p_input, p_output, count); + MlasComputeErf(p_input, p_output, static_cast(count)); }, 0); From 9439d8f683ca8d0a73446dbf31e77aacda5f7a26 Mon Sep 17 00:00:00 2001 From: Sanket Kale Date: Tue, 3 Feb 2026 14:21:48 +0530 Subject: [PATCH 5/7] Fixed formatting errors --- onnxruntime/core/providers/cpu/tensor/gelu.cc | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.cc b/onnxruntime/core/providers/cpu/tensor/gelu.cc index 1c0d6db2c98a8..21ef3fedad6b1 100644 --- a/onnxruntime/core/providers/cpu/tensor/gelu.cc +++ b/onnxruntime/core/providers/cpu/tensor/gelu.cc @@ -17,7 +17,7 @@ #include #if defined(_WIN32) - #include +#include #endif inline void* AlignedAlloc(size_t alignment, size_t size) { @@ -161,15 +161,15 @@ Status Gelu::Compute(OpKernelContext* context) const { void* raw = AlignedAlloc(alignment, aligned_size); if (!raw) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "Failed to allocate aligned temporary buffer."); + "Failed to allocate aligned temporary buffer."); } -auto deleter = [](MLFloat16* p) { - AlignedFree(p); -}; + auto deleter = [](MLFloat16* p) { + AlignedFree(p); + }; -std::unique_ptr temp_fp16_aligned( - static_cast(raw), deleter); + std::unique_ptr 
temp_fp16_aligned( + static_cast(raw), deleter); concurrency::ThreadPool::TryBatchParallelFor( tp, @@ -181,7 +181,6 @@ std::unique_ptr temp_fp16_aligned( int64_t count = std::min(length_per_task, elem_count - start); MLFloat16* p_temp = temp_fp16_aligned.get() + start; MlasComputeFP16Gelu(p_input, p_output, p_temp, count, approximation_algorithm_); - }, 0); return Status::OK(); From 4f10c21cc42ae90525cf8225da813eb848f58625 Mon Sep 17 00:00:00 2001 From: Sanket Kale Date: Thu, 5 Feb 2026 19:40:10 +0530 Subject: [PATCH 6/7] Added runtime guards and resolved CIfailures --- onnxruntime/core/mlas/lib/erf.cpp | 24 +++--------------- onnxruntime/core/mlas/lib/erf_neon_fp16.cpp | 3 +++ onnxruntime/core/mlas/lib/gelu.cpp | 8 +++--- onnxruntime/core/mlas/lib/mlasi.h | 25 +++++++++++++++++++ onnxruntime/core/mlas/lib/platform.cpp | 15 +++++++++++ .../providers/cpu/math/element_wise_ops.cc | 2 +- 6 files changed, 51 insertions(+), 26 deletions(-) diff --git a/onnxruntime/core/mlas/lib/erf.cpp b/onnxruntime/core/mlas/lib/erf.cpp index 293ccebbe34a8..193e2efc5fcd6 100644 --- a/onnxruntime/core/mlas/lib/erf.cpp +++ b/onnxruntime/core/mlas/lib/erf.cpp @@ -285,27 +285,9 @@ MlasComputeFP16Erf( ) { #if defined(MLAS_USE_SVE) || defined(MLAS_NEON_INTRINSICS) - -#if defined(MLAS_USE_SVE) && defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) - if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) { - MlasSveErfF16Kernel( - reinterpret_cast(Input), - reinterpret_cast<_mlas_fp16_*>(Output), - N - ); - return; - } -#endif - -#if defined(MLAS_NEON_INTRINSICS) && defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) - MlasNeonErfF16Kernel( - reinterpret_cast(Input), - reinterpret_cast<_mlas_fp16_*>(Output), - N - ); - return; -#endif - + #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) + GetMlasPlatform().ErfF16KernelRoutine(reinterpret_cast(Input), reinterpret_cast<_mlas_fp16_*>(Output), N); + #endif #else std::vector input_fp32(N); std::vector output_fp32(N); diff --git a/onnxruntime/core/mlas/lib/erf_neon_fp16.cpp b/onnxruntime/core/mlas/lib/erf_neon_fp16.cpp index 450aad217fa49..48d3e54bd9439 100644 --- a/onnxruntime/core/mlas/lib/erf_neon_fp16.cpp +++ b/onnxruntime/core/mlas/lib/erf_neon_fp16.cpp @@ -14,6 +14,8 @@ Module Name: #include "erf_neon_fp16.h" +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) + // Helpers to safely convert between float and FP16-bit representation static float fp16_to_float(uint16_t h) @@ -145,3 +147,4 @@ MlasNeonErfF16Kernel(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N) Output[i] = float_to_fp16(erf_approx); } } +#endif diff --git a/onnxruntime/core/mlas/lib/gelu.cpp b/onnxruntime/core/mlas/lib/gelu.cpp index 869671efd4630..a78921821c5f9 100644 --- a/onnxruntime/core/mlas/lib/gelu.cpp +++ b/onnxruntime/core/mlas/lib/gelu.cpp @@ -22,10 +22,10 @@ MlasComputeFP16Gelu(const MLAS_FP16* input, int64_t count, const std::string& algo) { -#if defined(MLAS_USE_SVE) && defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) - MlasSveGeluF16Kernel(input, output, temp, count, algo); -#elif defined(MLAS_NEON_INTRINSICS) && defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) - MlasNeonGeluF16Kernel(input, output, temp, count, algo); +#if defined(MLAS_USE_SVE) || defined(MLAS_NEON_INTRINSICS) + #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) + GetMlasPlatform().GeluF16KernelRoutine(input, output, temp, count, algo); + #endif #else (void)temp; for (int64_t i = 0; i < count; ++i) { diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index e75ca3dc90e60..2316126161f30 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ 
b/onnxruntime/core/mlas/lib/mlasi.h @@ -610,6 +610,25 @@ void size_t N ); +using _mlas_fp16_ = uint16_t; +typedef +void +(MLASCALL MLAS_COMPUTE_ERF_FP16_KERNEL)( + const _mlas_fp16_* Input, + _mlas_fp16_* Output, + size_t N +); + +typedef +void +(MLASCALL MLAS_COMPUTE_GELU_FP16_KERNEL)( + const MLAS_FP16* Input, + MLAS_FP16* Output, + MLAS_FP16* Temp, + int64_t N, + const std::string& Algo +); + typedef float (MLASCALL MLAS_COMPUTE_SUMEXP_FLOAT_KERNEL)( @@ -1057,6 +1076,8 @@ extern "C" { MLAS_QUANTIZE_LINEAR_U16_KERNEL MlasQuantizeLinearU16Kernel; MLAS_QUANTIZE_LINEAR_S4_KERNEL MlasQuantizeLinearS4Kernel; MLAS_QUANTIZE_LINEAR_U4_KERNEL MlasQuantizeLinearU4Kernel; + MLAS_COMPUTE_ERF_FP16_KERNEL MlasNeonErfF16Kernel; + MLAS_COMPUTE_GELU_FP16_KERNEL MlasNeonGeluF16Kernel; #if defined(MLAS_TARGET_AMD64) MLAS_DEQUANTIZE_LINEAR_S8_KERNEL MlasDequantizeLinearS8Kernel; MLAS_DEQUANTIZE_LINEAR_U8_KERNEL MlasDequantizeLinearU8Kernel; @@ -1410,6 +1431,10 @@ struct MLAS_PLATFORM { MLAS_COMPUTE_SUMEXP_FLOAT_KERNEL* ComputeSumExpF32Kernel; MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeLogSoftmaxOutputF32Kernel; MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeSoftmaxOutputF32Kernel; + #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) + MLAS_COMPUTE_ERF_FP16_KERNEL* ErfF16KernelRoutine; + MLAS_COMPUTE_GELU_FP16_KERNEL* GeluF16KernelRoutine; + #endif #endif #if defined(MLAS_TARGET_AMD64) MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1Routine; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index b913b1c3b8c26..69d1515eef145 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -19,6 +19,10 @@ Module Name: #ifdef MLAS_USE_SVE #include "sve/mlasi_sve.h" #endif +#if defined(MLAS_NEON_INTRINSICS) && defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) +#include "erf_neon_fp16.h" +#include "gelu.h" +#endif #if defined(USE_KLEIDIAI) #include "kleidiai/mlasi_kleidiai.h" #endif @@ -635,6 +639,17 @@ Return Value: this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32Kernel; this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32Kernel; } + + #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) { + this->ErfF16KernelRoutine = MlasSveErfF16Kernel; + this->GeluF16KernelRoutine = MlasSveGeluF16Kernel; + } + else{ + this->ErfF16KernelRoutine = MlasNeonErfF16Kernel; + this->GeluF16KernelRoutine = MlasNeonGeluF16Kernel; + } + #endif #endif // diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 9f35204a3ec28..f4bf60bd9b6a8 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -2040,7 +2040,7 @@ Status Erf::Compute(OpKernelContext* context) const { const int64_t count = std::min(length_per_task, elem_count - start); const MLFloat16* p_input = input_data + start; MLFloat16* p_output = output_data + start; - MlasComputeFP16Erf(p_input, p_output, count); + MlasComputeFP16Erf(p_input, p_output, static_cast(count)); }, 0); From 9bccbd8738c916fbcf579f844a80851805cfa5c6 Mon Sep 17 00:00:00 2001 From: Sanket Kale Date: Fri, 6 Feb 2026 20:17:12 +0530 Subject: [PATCH 7/7] Resolved MacOS and Web CI failures --- onnxruntime/core/mlas/lib/gelu_neon_fp16.cpp | 2 +- onnxruntime/core/mlas/lib/mlasi.h | 4 +-- onnxruntime/core/providers/cpu/tensor/gelu.cc | 26 +++++++++---------- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git 
a/onnxruntime/core/mlas/lib/gelu_neon_fp16.cpp b/onnxruntime/core/mlas/lib/gelu_neon_fp16.cpp index 8dc356a347baa..8802f9da7c987 100644 --- a/onnxruntime/core/mlas/lib/gelu_neon_fp16.cpp +++ b/onnxruntime/core/mlas/lib/gelu_neon_fp16.cpp @@ -13,7 +13,7 @@ Module Name: --*/ #include "gelu.h" #include -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) void MLASCALL diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 2316126161f30..5a33467e7ce7f 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1431,10 +1431,10 @@ struct MLAS_PLATFORM { MLAS_COMPUTE_SUMEXP_FLOAT_KERNEL* ComputeSumExpF32Kernel; MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeLogSoftmaxOutputF32Kernel; MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeSoftmaxOutputF32Kernel; - #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) +#endif +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) MLAS_COMPUTE_ERF_FP16_KERNEL* ErfF16KernelRoutine; MLAS_COMPUTE_GELU_FP16_KERNEL* GeluF16KernelRoutine; - #endif #endif #if defined(MLAS_TARGET_AMD64) MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1Routine; diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.cc b/onnxruntime/core/providers/cpu/tensor/gelu.cc index 21ef3fedad6b1..abf88392c2315 100644 --- a/onnxruntime/core/providers/cpu/tensor/gelu.cc +++ b/onnxruntime/core/providers/cpu/tensor/gelu.cc @@ -77,9 +77,9 @@ Status Gelu::Compute(OpKernelContext* context) const { T* output_data = output->MutableData(); concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); - int64_t elem_count = input->Shape().Size(); - constexpr int64_t length_per_task = 4096; // this number comes from FastGelu. - int64_t task_count = (elem_count + length_per_task - 1) / length_per_task; + size_t elem_count = input->Shape().Size(); + constexpr size_t length_per_task = 4096; // this number comes from FastGelu. + size_t task_count = (elem_count + length_per_task - 1) / length_per_task; if (approximation_algorithm_ == "tanh") { // FastGelu allows optional bias. Here we split input data into chunks. 
Each chunk @@ -95,16 +95,16 @@ Status Gelu::Compute(OpKernelContext* context) const { const auto start = task_idx * length_per_task; const T* p_input = input_data + start; T* p_output = output_data + start; - int64_t count = std::min(length_per_task, elem_count - start); + size_t count = std::min(length_per_task, elem_count - start); - for (int64_t i = 0; i < count; i++) { + for (size_t i = 0; i < count; i++) { T value = p_input[i]; p_output[i] = value * (static_cast(C) * value * value + static_cast(B)); } MlasComputeTanh(p_output, p_output, narrow(count)); - for (int64_t i = 0; i < count; i++) { + for (size_t i = 0; i < count; i++) { p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f); } }, @@ -117,16 +117,16 @@ Status Gelu::Compute(OpKernelContext* context) const { const auto start = task_idx * length_per_task; const T* p_input = input_data + start; T* p_output = output_data + start; - int64_t count = std::min(length_per_task, elem_count - start); + size_t count = std::min(length_per_task, elem_count - start); - for (int64_t i = 0; i < count; i++) { + for (size_t i = 0; i < count; i++) { T value = p_input[i]; p_output[i] = value * static_cast(M_SQRT1_2); } MlasComputeErf(p_output, p_output, narrow(count)); - for (int64_t i = 0; i < count; i++) { + for (size_t i = 0; i < count; i++) { p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f); } }, @@ -143,9 +143,9 @@ Status Gelu::Compute(OpKernelContext* context) const { Tensor* output = context->Output(0, input->Shape()); MLFloat16* output_data = output->MutableData(); concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); - int64_t elem_count = input->Shape().Size(); - constexpr int64_t length_per_task = 4096; - int64_t task_count = (elem_count + length_per_task - 1) / length_per_task; + size_t elem_count = input->Shape().Size(); + constexpr size_t length_per_task = 4096; + size_t task_count = (elem_count + length_per_task - 1) / length_per_task; if (approximation_algorithm_ != "tanh" && approximation_algorithm_ != "none") { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported approximation_algorithm: ", approximation_algorithm_); @@ -178,7 +178,7 @@ Status Gelu::Compute(OpKernelContext* context) const { const auto start = task_idx * length_per_task; const MLFloat16* p_input = input_data + start; MLFloat16* p_output = output_data + start; - int64_t count = std::min(length_per_task, elem_count - start); + size_t count = std::min(length_per_task, elem_count - start); MLFloat16* p_temp = temp_fp16_aligned.get() + start; MlasComputeFP16Gelu(p_input, p_output, p_temp, count, approximation_algorithm_); },
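Editor's note: for readers following the kernel math, the two formulations approximated throughout this series are the exact erf form and the tanh form, with B = 0.7978845608028654 (sqrt(2/pi)) and C = 0.035677408136300125 (0.044715 * B), matching the constants used in MlasNeonGeluF16Kernel and MlasSveGeluF16Kernel and in the scalar fallbacks. The scalar reference below is illustration only, not part of the patch, and its function names are hypothetical.

// Scalar reference for the two GELU formulations targeted by the FP16 kernels.
// Illustrative only; not part of the patch.
#include <cmath>

// Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
inline float GeluErfReference(float x) {
    return 0.5f * x * (1.0f + std::erf(x * 0.70710678118654752440f /* 1/sqrt(2) */));
}

// Tanh approximation: 0.5 * x * (1 + tanh(x * (B + C * x^2)))
// with B = sqrt(2/pi) and C = 0.044715 * B, matching the kernel constants above.
inline float GeluTanhReference(float x) {
    constexpr float B = 0.7978845608028654f;
    constexpr float C = 0.035677408136300125f;
    return 0.5f * x * (1.0f + std::tanh(x * (B + C * x * x)));
}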
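Editor's note: MlasReciprocalStepFloat16 wraps vrecpsq_f16, which returns 2 - a*b, so the repeated recip = MlasMultiplyFloat16(recip, MlasReciprocalStepFloat16(den, recip)) lines in the erf kernels perform two Newton-Raphson refinements of the hardware reciprocal estimate. A scalar sketch of that refinement follows; it is illustrative only, and the names are made up.

// Scalar model of the reciprocal refinement used around MlasReciprocalStepFloat16:
// start from a rough estimate of 1/d and apply Newton-Raphson steps r = r * (2 - d * r).
#include <cstdio>

inline float ReciprocalStep(float d, float r) { return 2.0f - d * r; }  // what vrecps computes

float RefinedReciprocal(float d, float rough_estimate) {
    float r = rough_estimate;
    r = r * ReciprocalStep(d, r);  // first refinement step
    r = r * ReciprocalStep(d, r);  // second refinement step
    return r;
}

int main() {
    // With a deliberately poor initial estimate, two steps already land close to 1/3.
    std::printf("%f\n", RefinedReciprocal(3.0f, 0.25f));
    return 0;
}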
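Editor's note: the aligned temporary buffer in gelu.cc follows a small portable pattern: round the byte count up to a multiple of the alignment (std::aligned_alloc requires this), allocate through the AlignedAlloc/AlignedFree wrappers, and hand ownership to a unique_ptr with a custom deleter. A self-contained sketch is below, assuming the same 64-byte alignment as the patch; the demo main and the uint16_t element type stand in for MLFloat16.

// Minimal sketch of the aligned temp-buffer pattern used in gelu.cc.
#include <cstdlib>
#include <cstdint>
#include <memory>
#if defined(_WIN32)
#include <malloc.h>
#endif

inline void* AlignedAlloc(size_t alignment, size_t size) {
#if defined(_WIN32)
    return _aligned_malloc(size, alignment);
#else
    return std::aligned_alloc(alignment, size);  // size must be a multiple of alignment
#endif
}

inline void AlignedFree(void* p) {
#if defined(_WIN32)
    _aligned_free(p);
#else
    std::free(p);
#endif
}

int main() {
    constexpr size_t alignment = 64;
    const size_t elem_count = 1000;                      // number of fp16 elements
    const size_t buffer_size = elem_count * sizeof(uint16_t);
    const size_t aligned_size = ((buffer_size + alignment - 1) / alignment) * alignment;

    auto deleter = [](uint16_t* p) { AlignedFree(p); };
    std::unique_ptr<uint16_t[], decltype(deleter)> temp(
        static_cast<uint16_t*>(AlignedAlloc(alignment, aligned_size)), deleter);
    return temp ? 0 : 1;
}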
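Editor's note: the later patches route MlasComputeFP16Erf and MlasComputeFP16Gelu through the ErfF16KernelRoutine/GeluF16KernelRoutine pointers that platform.cpp selects once, based on HasArmSve(), instead of branching per call. The sketch below shows the general shape of that dispatch with made-up names and trivial stand-in kernels; it is not the MLAS code itself.

// Simplified illustration of kernel-pointer dispatch: pick the SVE or NEON kernel
// once at platform construction, then call through the pointer. Names are illustrative.
#include <cstddef>
#include <cstdint>
#include <cstring>

using Fp16ErfKernel = void (*)(const uint16_t*, uint16_t*, size_t);

// Stand-in kernels: real code would run SVE/NEON intrinsics here.
static void ErfF16Sve(const uint16_t* in, uint16_t* out, size_t n)  { std::memcpy(out, in, n * sizeof(uint16_t)); }
static void ErfF16Neon(const uint16_t* in, uint16_t* out, size_t n) { std::memcpy(out, in, n * sizeof(uint16_t)); }
static bool CpuHasSve() { return false; }  // stand-in for the CPUID query

struct Platform {
    Fp16ErfKernel erf_f16;
    Platform() : erf_f16(CpuHasSve() ? ErfF16Sve : ErfF16Neon) {}
};

inline Platform& GetPlatform() {
    static Platform platform;  // constructed once; mirrors the GetMlasPlatform() pattern
    return platform;
}

void ComputeFp16Erf(const uint16_t* input, uint16_t* output, size_t n) {
    GetPlatform().erf_f16(input, output, n);  // single indirect call, no per-call feature checks
}

int main() {
    uint16_t in[4] = {1, 2, 3, 4};
    uint16_t out[4] = {};
    ComputeFp16Erf(in, out, 4);
    return out[3] == in[3] ? 0 : 1;
}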