diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index d7dcde945e6d7..3cae38e5bae05 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -54,6 +54,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
   ${MLAS_SRC_DIR}/rotary_embedding.cpp
   ${MLAS_SRC_DIR}/softmax.h
   ${MLAS_SRC_DIR}/saturation_check.cpp
+  ${MLAS_SRC_DIR}/gelu.cpp
 )
 
 target_sources(onnxruntime_mlas PRIVATE
@@ -116,6 +117,9 @@ function(setup_mlas_source_for_windows)
     ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp
     ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
     ${MLAS_SRC_DIR}/sconv_nchw_kernel_neon.cpp
+    ${MLAS_SRC_DIR}/erf_neon_fp16.h
+    ${MLAS_SRC_DIR}/erf_neon_fp16.cpp
+    ${MLAS_SRC_DIR}/gelu_neon_fp16.cpp
   )
 
   set(mlas_platform_preprocess_srcs
@@ -479,13 +483,18 @@ else()
       ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
       ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
       ${MLAS_SRC_DIR}/sconv_nchw_kernel_neon.cpp
+      ${MLAS_SRC_DIR}/erf_neon_fp16.h
+      ${MLAS_SRC_DIR}/erf_neon_fp16.cpp
+      ${MLAS_SRC_DIR}/gelu_neon_fp16.cpp
     )
 
     # Conditionally add the SVE implementation if compiler supports it
     if (onnxruntime_USE_SVE)
       list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/mlasi_sve.h)
       list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/elementwise_sve.cpp)
+      list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/elementwise_sve_fp16.cpp)
       set_source_files_properties(${MLAS_SRC_DIR}/sve/elementwise_sve.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 ")
+      set_source_files_properties(${MLAS_SRC_DIR}/sve/elementwise_sve_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 ")
       list(APPEND mlas_private_compile_definitions MLAS_USE_SVE)
     endif()
 
@@ -522,6 +531,8 @@ else()
       ${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp
       ${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp
       ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp
+      ${MLAS_SRC_DIR}/erf_neon_fp16.cpp
+      ${MLAS_SRC_DIR}/gelu_neon_fp16.cpp
     )
     set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
     set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
@@ -538,6 +549,8 @@ else()
     set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
     set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
     set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+    set_source_files_properties(${MLAS_SRC_DIR}/erf_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+    set_source_files_properties(${MLAS_SRC_DIR}/gelu_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
   endif()
 
 if(ONNXRUNTIME_MLAS_MULTI_ARCH)
diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index db7ec288001f9..f9370dc1fa380 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -2127,3 +2127,32 @@ MlasFlashAttention(
     MlasFlashAttentionThreadedArgs* args,
     MLAS_THREADPOOL* ThreadPool
 );
+
+#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
+/**
+ * @brief Function to override the packing mechanism decision if KleidiAI is included
+ * @param enable enable KleidiAI packing (allow or disallow depending on true/false)
+ * @return
+*/
+void
+MLASCALL
+MlasGemmBatchPackUseKleidi(bool enable);
+#endif
+
+void
+MLASCALL
+MlasComputeFP16Erf(
+    const MLAS_FP16* Input,
+    MLAS_FP16* Output,
+    size_t N
+);
+
+void
+MLASCALL
+MlasComputeFP16Gelu(
+    const MLAS_FP16* input,
+    MLAS_FP16* output,
+    MLAS_FP16* temp,
+    int64_t count,
+    const std::string& algo
+);
diff --git a/onnxruntime/core/mlas/lib/erf.cpp b/onnxruntime/core/mlas/lib/erf.cpp
index f9724062e1f4d..193e2efc5fcd6 100644
--- a/onnxruntime/core/mlas/lib/erf.cpp
+++ b/onnxruntime/core/mlas/lib/erf.cpp
@@ -22,6 +22,15 @@ Module Name:
 --*/
 
 #include "mlasi.h"
+
+#include <vector>
+
+#ifdef MLAS_USE_SVE
+#include "sve/mlasi_sve.h"
+#endif
+
+#if defined(MLAS_NEON_INTRINSICS) && defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
+#include "erf_neon_fp16.h"
+#endif
+
 //
 // Bundles the constants for use by kernels written in assembly.
 //
@@ -266,3 +275,26 @@ Return Value:
     MlasErfKernel(Input, Output, N);
 #endif
 }
+
+void
+MLASCALL
+MlasComputeFP16Erf(
+    const MLAS_FP16* Input,
+    MLAS_FP16* Output,
+    size_t N
+    )
+{
+#if (defined(MLAS_USE_SVE) || defined(MLAS_NEON_INTRINSICS)) && defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
+    GetMlasPlatform().ErfF16KernelRoutine(reinterpret_cast<const _mlas_fp16_*>(Input),
+                                          reinterpret_cast<_mlas_fp16_*>(Output), N);
+#else
+    std::vector<float> input_fp32(N);
+    std::vector<float> output_fp32(N);
+
+    MlasConvertHalfToFloatBuffer(Input, input_fp32.data(), N);
+    MlasComputeErf(input_fp32.data(), output_fp32.data(), N);
+    MlasConvertFloatToHalfBuffer(output_fp32.data(), Output, N);
+#endif
+}
\ No newline at end of file
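
For reference, the fallback branch above preserves the existing numerical contract: widen to fp32, reuse MlasComputeErf, and narrow back, so results should track std::erf to within fp16 rounding plus the vector kernels' approximation error. A minimal smoke test of the new entry point; a sketch only, assuming a translation unit built against MLAS and using MLAS_FP16's float constructor/conversion as seen elsewhere in this diff:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    #include "mlas.h"  // assumed include path into onnxruntime/core/mlas/inc

    int main() {
        const size_t n = 33;
        std::vector<MLAS_FP16> x(n), y(n);
        for (size_t i = 0; i < n; ++i) {
            x[i] = MLAS_FP16(-4.0f + 0.25f * static_cast<float>(i));  // sweep [-4, 4]
        }
        MlasComputeFP16Erf(x.data(), y.data(), n);
        float max_err = 0.0f;
        for (size_t i = 0; i < n; ++i) {
            float ref = std::erf(static_cast<float>(x[i]));
            max_err = std::fmax(max_err, std::fabs(static_cast<float>(y[i]) - ref));
        }
        std::printf("max |erf error| = %g\n", max_err);  // expect fp16 rounding plus polynomial error
        return 0;
    }
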
diff --git a/onnxruntime/core/mlas/lib/erf_neon_fp16.cpp b/onnxruntime/core/mlas/lib/erf_neon_fp16.cpp
new file mode 100644
index 0000000000000..48d3e54bd9439
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/erf_neon_fp16.cpp
@@ -0,0 +1,150 @@
+/*++
+
+Copyright 2025 FUJITSU LIMITED
+
+Module Name:
+
+    erf_neon_fp16.cpp
+
+Abstract:
+
+    This module contains the ERF NEON FP16 kernel implementation.
+
+--*/
+
+#include <cstring>  // memcpy
+
+#include "erf_neon_fp16.h"
+
+#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
+
+// Helpers to safely convert between float and the FP16 bit representation.
+static float
+fp16_to_float(uint16_t h)
+{
+    __fp16 tmp;
+    memcpy(&tmp, &h, sizeof(h));
+    return (float)tmp;
+}
+
+static uint16_t
+float_to_fp16(float f)
+{
+    __fp16 tmp = (__fp16)f;
+    uint16_t h;
+    memcpy(&h, &tmp, sizeof(h));
+    return h;
+}
+
+static inline MLAS_FLOAT16X8
+exp_neg_rational_approx_f16(MLAS_FLOAT16X8 x)
+{
+    const float16_t a0 = 6.0f;
+    MLAS_FLOAT16X8 max_x = MlasBroadcastF16Float16x8(a0);
+    x = MlasMinimumFloat16(x, max_x);
+
+    const float16_t c0 = 1.330f;
+    const float16_t c1 = -0.390f;
+    const float16_t c2 = 0.0288f;
+
+    const float16_t d0 = 1.338f;
+    const float16_t d1 = 0.848f;
+    const float16_t d2 = 0.467f;
+
+    MLAS_FLOAT16X8 c0v = MlasBroadcastF16Float16x8(c0);
+    MLAS_FLOAT16X8 c1v = MlasBroadcastF16Float16x8(c1);
+    MLAS_FLOAT16X8 c2v = MlasBroadcastF16Float16x8(c2);
+
+    MLAS_FLOAT16X8 d0v = MlasBroadcastF16Float16x8(d0);
+    MLAS_FLOAT16X8 d1v = MlasBroadcastF16Float16x8(d1);
+    MLAS_FLOAT16X8 d2v = MlasBroadcastF16Float16x8(d2);
+
+    MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16(x, x);
+    MLAS_FLOAT16X8 num = MlasMultiplyAddFloat16(c1v, x, c0v);
+    num = MlasMultiplyAddFloat16(c2v, x2, num);
+    MLAS_FLOAT16X8 den = MlasMultiplyAddFloat16(d1v, x, d0v);
+    den = MlasMultiplyAddFloat16(d2v, x2, den);
+    MLAS_FLOAT16X8 recip = MlasApproximateReciprocalFloat16(den);
+    recip = MlasMultiplyFloat16(recip, MlasReciprocalStepFloat16(den, recip));
+    recip = MlasMultiplyFloat16(recip, MlasReciprocalStepFloat16(den, recip));
+    MLAS_FLOAT16X8 result = MlasMultiplyFloat16(num, recip);
+    return result;
+}
+
+void
+MlasNeonErfF16Kernel(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N)
+{
+    const float16_t p = 0.328f;
+    const float16_t a1 = 0.2505f;
+    const float16_t a2 = -0.2881f;
+    const float16_t a3 = 1.4102f;
+    const float16_t a4 = -1.423f;
+    const float16_t a5 = 1.0547f;
+
+    MLAS_FLOAT16X8 vp = MlasBroadcastF16Float16x8(p);
+    MLAS_FLOAT16X8 va1 = MlasBroadcastF16Float16x8(a1);
+    MLAS_FLOAT16X8 va2 = MlasBroadcastF16Float16x8(a2);
+    MLAS_FLOAT16X8 va3 = MlasBroadcastF16Float16x8(a3);
+    MLAS_FLOAT16X8 va4 = MlasBroadcastF16Float16x8(a4);
+    MLAS_FLOAT16X8 va5 = MlasBroadcastF16Float16x8(a5);
+
+    constexpr float16_t one_fp16 = 1.0f;
+    constexpr float16_t neg_one_fp16 = -1.0f;
+    constexpr float16_t zero_fp16 = 0.0f;
+    constexpr float16_t four_fp16 = 4.0f;
+
+    MLAS_FLOAT16X8 vone = MlasBroadcastF16Float16x8(one_fp16);
+    MLAS_FLOAT16X8 vneg_one = MlasBroadcastF16Float16x8(neg_one_fp16);
+    MLAS_FLOAT16X8 vzero = MlasBroadcastF16Float16x8(zero_fp16);
+    MLAS_FLOAT16X8 vth = MlasBroadcastF16Float16x8(four_fp16);
+
+    size_t i = 0;
+    for (; i + 8 <= N; i += 8) {
+        MLAS_FLOAT16X8 x = MlasLoadFloat16x8(&Input[i]);
+        MLAS_UINT16X8 neg_mask = MlasCompareLessThanFloat16(x, vzero);
+        MLAS_FLOAT16X8 sign = MlasSelectFloat16(neg_mask, vneg_one, vone);
+        MLAS_FLOAT16X8 absx = MlasAbsFloat16(x);
+        MLAS_UINT16X8 use_mask = MlasCompareLessThanFloat16(absx, vth);
+        MLAS_FLOAT16X8 absx_clamped = MlasMinimumFloat16(absx, vth);
+        MLAS_FLOAT16X8 denom = MlasMultiplyAddFloat16(vp, absx_clamped, vone);
+        MLAS_FLOAT16X8 t = MlasApproximateReciprocalFloat16(denom);
+        t = MlasMultiplyFloat16(t, MlasReciprocalStepFloat16(denom, t));
+        t = MlasMultiplyFloat16(t, MlasReciprocalStepFloat16(denom, t));
+        MLAS_FLOAT16X8 t2 = MlasMultiplyFloat16(t, t);
+        MLAS_FLOAT16X8 t3 = MlasMultiplyFloat16(t2, t);
+        MLAS_FLOAT16X8 t4 = MlasMultiplyFloat16(t3, t);
+        MLAS_FLOAT16X8 t5 = MlasMultiplyFloat16(t4, t);
+        MLAS_FLOAT16X8 poly = MlasMultiplyFloat16(va1, t);
+        poly = MlasMultiplyAddFloat16(va2, t2, poly);
+        poly = MlasMultiplyAddFloat16(va3, t3, poly);
+        poly = MlasMultiplyAddFloat16(va4, t4, poly);
+        poly = MlasMultiplyAddFloat16(va5, t5, poly);
+        MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16(absx_clamped, absx_clamped);
+        MLAS_FLOAT16X8 exp_neg_x2 = exp_neg_rational_approx_f16(x2);
+        MLAS_FLOAT16X8 poly_mul_exp = MlasMultiplyFloat16(poly, exp_neg_x2);
+        MLAS_FLOAT16X8 one_minus_term = MlasSubtractFloat16(vone, poly_mul_exp);
+        MLAS_FLOAT16X8 erf_approx = MlasMultiplyFloat16(sign, one_minus_term);
+        erf_approx = MlasMinimumFloat16(erf_approx, vone);
+        erf_approx = MlasMaximumFloat16(erf_approx, vneg_one);
+        MLAS_FLOAT16X8 result = MlasSelectFloat16(use_mask, erf_approx, sign);
+        MlasStoreFloat16x8(&Output[i], result);
+    }
+
+    for (; i < N; i++) {
+        float x = fp16_to_float(Input[i]);
+        float sign = (x < 0) ? -1.0f : 1.0f;
+        float absx = fabsf(x);
+
+        if (absx > 4.0f) {
+            Output[i] = float_to_fp16(sign);
+            continue;
+        }
+
+        float t = 1.0f / (1.0f + p * absx);
+        float poly = a1 * t + a2 * t * t + a3 * t * t * t + a4 * t * t * t * t + a5 * t * t * t * t * t;
+        float exp_neg_x2 = expf(-absx * absx);
+        float erf_approx = sign * (1.0f - poly * exp_neg_x2);
+        if (erf_approx > 1.0f) erf_approx = 1.0f;
+        if (erf_approx < -1.0f) erf_approx = -1.0f;
+
+        Output[i] = float_to_fp16(erf_approx);
+    }
+}
+#endif
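
The scalar tail above makes the formula explicit: this is the Abramowitz & Stegun 7.1.26 form, erf(x) ≈ sign(x) * (1 - (a1*t + a2*t^2 + ... + a5*t^5) * exp(-x^2)) with t = 1/(1 + p*|x|), using coefficients rounded for fp16 (the classic values are p = 0.3275911, a1 = 0.254829592, ...) and saturating to ±1 beyond |x| = 4. A standalone double-precision check of what the rounding costs; a sketch, not part of the change:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    // The fp16-rounded coefficients used by the kernel above.
    static double erf_approx(double x) {
        const double p = 0.328, a1 = 0.2505, a2 = -0.2881, a3 = 1.4102, a4 = -1.423, a5 = 1.0547;
        double sign = (x < 0) ? -1.0 : 1.0;
        double ax = std::fabs(x);
        if (ax > 4.0) return sign;  // the kernel saturates beyond |x| = 4
        double t = 1.0 / (1.0 + p * ax);
        double poly = ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t;
        return sign * (1.0 - poly * std::exp(-ax * ax));
    }

    int main() {
        double max_err = 0.0;
        for (double x = -4.0; x <= 4.0; x += 1.0 / 256) {
            max_err = std::max(max_err, std::fabs(erf_approx(x) - std::erf(x)));
        }
        std::printf("max abs error: %.3g\n", max_err);  // typically a few 1e-3, comparable to fp16 resolution
        return 0;
    }
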
diff --git a/onnxruntime/core/mlas/lib/erf_neon_fp16.h b/onnxruntime/core/mlas/lib/erf_neon_fp16.h
new file mode 100644
index 0000000000000..c8fd77e0e62d1
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/erf_neon_fp16.h
@@ -0,0 +1,24 @@
+/*++
+
+Copyright 2025 FUJITSU LIMITED
+
+Module Name:
+
+    erf_neon_fp16.h
+
+Abstract:
+
+    This module contains the procedure prototypes for the ERF NEON FP16 intrinsics.
+
+--*/
+
+#pragma once
+
+#include <cstdint>
+
+#include "mlasi.h"
+#include "fp16_common.h"
+#include "softmax_kernel_neon.h"
+
+using _mlas_fp16_ = uint16_t;
+
+void MlasNeonErfF16Kernel(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N);
diff --git a/onnxruntime/core/mlas/lib/fp16_common.h b/onnxruntime/core/mlas/lib/fp16_common.h
index d4713cce5a176..31ad706fa26e2 100644
--- a/onnxruntime/core/mlas/lib/fp16_common.h
+++ b/onnxruntime/core/mlas/lib/fp16_common.h
@@ -54,6 +54,10 @@ MLAS_FORCEINLINE
 MLAS_FLOAT16X8
 MlasBroadcastFloat16x8(_mlas_fp16_ Value) { return vreinterpretq_f16_p16(vdupq_n_p16(Value)); }
 
+MLAS_FORCEINLINE
+MLAS_FLOAT16X8
+MlasBroadcastF16Float16x8(float16_t Value) { return vdupq_n_f16(Value); }
+
 MLAS_FORCEINLINE
 MLAS_FLOAT16X4
 MlasBroadcastFloat16x4(_mlas_fp16_ Value) { return vreinterpret_f16_p16(vdup_n_p16(Value)); }
@@ -78,6 +82,10 @@ MLAS_FORCEINLINE
 MLAS_FLOAT16X8
 MlasLoadFloat16x8(const _mlas_fp16_* Buffer) { return vreinterpretq_f16_u16(vld1q_u16(Buffer)); }
 
+MLAS_FORCEINLINE
+MLAS_FLOAT16X8
+MlasLoadf16Float16x8(const float16_t* Buffer) { return vld1q_f16(Buffer); }
+
 MLAS_FORCEINLINE
 MLAS_FLOAT16X4
 MlasLoadFloat16x4(const _mlas_fp16_* Buffer) { return vreinterpret_f16_u16(vld1_u16(Buffer)); }
@@ -115,6 +123,13 @@ MlasStoreFloat16x8(_mlas_fp16_* Buffer, MLAS_FLOAT16X8 Vector)
     vst1q_u16(Buffer, vreinterpretq_u16_f16(Vector));
 }
 
+MLAS_FORCEINLINE
+void
+MlasStoref16Float16x8(float16_t* Buffer, MLAS_FLOAT16X8 Vector)
+{
+    vst1q_f16(Buffer, Vector);
+}
+
 MLAS_FORCEINLINE
 void
 MlasStoreFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector)
@@ -579,4 +594,39 @@ MlasShiftLeftInt16(MLAS_INT16X4 Vector)
     return vshl_n_s16(Vector, ShiftCount);
 }
 
+MLAS_FORCEINLINE
+MLAS_FLOAT16X8
+MlasReciprocalStepFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
+{
+    return vrecpsq_f16(Vector1, Vector2);
+}
+
+MLAS_FORCEINLINE
+MLAS_FLOAT16X8
+MlasApproximateReciprocalFloat16(MLAS_FLOAT16X8 Vector)
+{
+    return vrecpeq_f16(Vector);
+}
+
+MLAS_FORCEINLINE
+MLAS_UINT16X8
+MlasCompareLessThanFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
+{
+    return vcltq_f16(Vector1, Vector2);
+}
+
+MLAS_FORCEINLINE
+MLAS_FLOAT16X8
+MlasAbsFloat16(MLAS_FLOAT16X8 Vector)
+{
+    return vabsq_f16(Vector);
+}
+
+MLAS_FORCEINLINE
+MLAS_FLOAT16X8
+MlasSelectFloat16(MLAS_UINT16X8 Vector, MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
+{
+    return vbslq_f16(Vector, Vector1, Vector2);
+}
+
 #endif  // fp16 vector intrinsic supported
diff --git a/onnxruntime/core/mlas/lib/gelu.cpp b/onnxruntime/core/mlas/lib/gelu.cpp
new file mode 100644
index 0000000000000..a78921821c5f9
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/gelu.cpp
@@ -0,0 +1,51 @@
+/*++
+
+Copyright 2025 FUJITSU LIMITED
+
+Module Name:
+
+    gelu.cpp
+
+Abstract:
+
+    This module contains Gelu helper functions.
+
+--*/
+
+#include <cmath>
+#include <string>
+
+#include "gelu.h"
+
+void
+MLASCALL
+MlasComputeFP16Gelu(const MLAS_FP16* input,
+                    MLAS_FP16* output,
+                    MLAS_FP16* temp,
+                    int64_t count,
+                    const std::string& algo)
+{
+#if (defined(MLAS_USE_SVE) || defined(MLAS_NEON_INTRINSICS)) && defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
+    GetMlasPlatform().GeluF16KernelRoutine(input, output, temp, count, algo);
+#else
+    (void)temp;
+    for (int64_t i = 0; i < count; ++i) {
+        float x = static_cast<float>(input[i]);
+        float gelu_val;
+
+        if (algo == "tanh") {
+            // GELU approximation (tanh)
+            const float B = 0.7978845608f;
+            const float C = 0.044715f * B;
+            float tanh_arg = x * (B + C * x * x);
+            float tanh_res = std::tanh(tanh_arg);
+            gelu_val = 0.5f * x * (1.0f + tanh_res);
+        } else {
+            // GELU exact (erf)
+            gelu_val = 0.5f * x *
+                       (1.0f + std::erf(x * static_cast<float>(M_SQRT1_2)));
+        }
+
+        output[i] = MLAS_FP16(gelu_val);
+    }
+#endif
+}
diff --git a/onnxruntime/core/mlas/lib/gelu.h b/onnxruntime/core/mlas/lib/gelu.h
new file mode 100644
index 0000000000000..554afd5d41575
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/gelu.h
@@ -0,0 +1,35 @@
+/*++
+
+Copyright 2025 FUJITSU LIMITED
+
+Module Name:
+
+    gelu.h
+
+Abstract:
+
+    This module contains Gelu helper function declarations.
+
+--*/
+
+#pragma once
+
+#include "fp16_common.h"
+
+#if defined(MLAS_NEON_INTRINSICS) && defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
+#include "erf_neon_fp16.h"
+
+void
+MLASCALL
+MlasNeonGeluF16Kernel(
+    const MLAS_FP16* input,
+    MLAS_FP16* output,
+    MLAS_FP16* temp,
+    int64_t count,
+    const std::string& algo
+);
+
+#endif
+
+#ifdef MLAS_USE_SVE
+#include "sve/mlasi_sve.h"
+#endif
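
Both GELU paths implement the standard definitions: the exact form 0.5*x*(1 + erf(x/sqrt(2))) and the tanh approximation 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))). The NEON and SVE kernels below fold sqrt(2/pi) ≈ 0.7978845608 into B and 0.044715*B ≈ 0.0356774 into C, so the tanh argument becomes x*(B + C*x^2); the two forms are algebraically identical, as this throwaway check confirms:

    #include <cmath>
    #include <cstdio>

    int main() {
        const double B = 0.7978845608028654;  // sqrt(2/pi)
        const double C = 0.044715 * B;        // folded constant, matching v_C1 in the kernels
        for (double x = -3.0; x <= 3.0; x += 0.25) {
            double textbook = 0.5 * x * (1.0 + std::tanh(B * (x + 0.044715 * x * x * x)));
            double folded = 0.5 * x * (1.0 + std::tanh(x * (B + C * x * x)));
            std::printf("x=%5.2f  |diff|=%.3g\n", x, std::fabs(textbook - folded));  // ~0 up to rounding
        }
        return 0;
    }
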
diff --git a/onnxruntime/core/mlas/lib/gelu_neon_fp16.cpp b/onnxruntime/core/mlas/lib/gelu_neon_fp16.cpp
new file mode 100644
index 0000000000000..8802f9da7c987
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/gelu_neon_fp16.cpp
@@ -0,0 +1,93 @@
+/*++
+
+Copyright 2025 FUJITSU LIMITED
+
+Module Name:
+
+    gelu_neon_fp16.cpp
+
+Abstract:
+
+    This module contains the Gelu NEON FP16 kernel implementation.
+
+--*/
+#include "gelu.h"
+#include <algorithm>
+
+#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
+
+void
+MLASCALL
+MlasNeonGeluF16Kernel(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp, int64_t count, const std::string& algo)
+{
+    const float16_t v_half1 = 0.5f;
+    const float16_t v_one1 = 1.0f;
+    const float16_t v_sqrt1_21 = static_cast<float16_t>(M_SQRT1_2);
+    const float16_t v_B1 = 0.7978845608028654f;
+    const float16_t v_C1 = 0.035677408136300125f;
+    const float16_t c1 = 5.0f;
+    const float16_t c2 = -5.0f;
+    const MLAS_FLOAT16X8 v_half = MlasBroadcastF16Float16x8(v_half1);
+    const MLAS_FLOAT16X8 v_one = MlasBroadcastF16Float16x8(v_one1);
+    const MLAS_FLOAT16X8 v_sqrt1_2 = MlasBroadcastF16Float16x8(v_sqrt1_21);
+    const MLAS_FLOAT16X8 v_B = MlasBroadcastF16Float16x8(v_B1);
+    const MLAS_FLOAT16X8 v_C = MlasBroadcastF16Float16x8(v_C1);
+
+    int64_t i = 0;
+
+    if (algo == "tanh") {
+        // Preprocess input into temp[] for tanh
+        for (; i + 7 < count; i += 8) {
+            MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast<const float16_t*>(input + i));
+            MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16(x, x);
+            MLAS_FLOAT16X8 inner = MlasMultiplyAddFloat16(v_C, x2, v_B);  // B + C * x^2
+            MLAS_FLOAT16X8 tanh_arg = MlasMultiplyFloat16(x, inner);      // x * (B + C * x^2)
+            tanh_arg = MlasMaximumFloat16(MlasBroadcastF16Float16x8(c2), MlasMinimumFloat16(tanh_arg, MlasBroadcastF16Float16x8(c1)));
+            MlasStoref16Float16x8(reinterpret_cast<float16_t*>(temp + i), tanh_arg);
+        }
+
+        // Tail
+        for (; i < count; ++i) {
+            float x = static_cast<float>(input[i]);
+            float inner = x * (0.7978845608028654f + 0.035677408136300125f * x * x);
+            inner = std::max(-5.0f, std::min(5.0f, inner));
+            temp[i] = static_cast<MLAS_FP16>(inner);
+        }
+
+        // Tanh processing
+        MlasComputeTanh(temp, temp, count);
+
+    } else {
+        // Preprocess input into temp[] for erf
+        for (i = 0; i + 7 < count; i += 8) {
+            MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast<const float16_t*>(input + i));
+            MLAS_FLOAT16X8 scaled = MlasMultiplyFloat16(x, v_sqrt1_2);
+            MlasStoref16Float16x8(reinterpret_cast<float16_t*>(temp + i), scaled);
+        }
+
+        // Tail
+        for (; i < count; ++i) {
+            float x = static_cast<float>(input[i]);
+            temp[i] = static_cast<MLAS_FP16>(x * static_cast<float>(M_SQRT1_2));
+        }
+
+        // Erf processing
+        MlasNeonErfF16Kernel(reinterpret_cast<const _mlas_fp16_*>(temp), reinterpret_cast<_mlas_fp16_*>(temp), count);
+    }
+
+    // Final GELU output = 0.5 * x * (1 + tanh|erf)
+    i = 0;
+    for (; i + 7 < count; i += 8) {
+        MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast<const float16_t*>(input + i));
+        MLAS_FLOAT16X8 t = MlasLoadf16Float16x8(reinterpret_cast<const float16_t*>(temp + i));
+        MLAS_FLOAT16X8 result = MlasMultiplyFloat16(v_half, MlasMultiplyFloat16(x, MlasAddFloat16(v_one, t)));
+        MlasStoref16Float16x8(reinterpret_cast<float16_t*>(output + i), result);
+    }
+
+    for (; i < count; ++i) {
+        float x = static_cast<float>(input[i]);
+        float t = static_cast<float>(temp[i]);
+        float gelu = 0.5f * x * (1.0f + t);
+        output[i] = static_cast<MLAS_FP16>(gelu);
+    }
+}
+#endif
\ No newline at end of file
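
One detail worth calling out in the kernel above: the tanh argument is clamped to [-5, 5] before MlasComputeTanh. At fp16 precision this is lossless, because tanh already rounds to ±1 at that magnitude, and it keeps the intermediate values well inside fp16 range. A quick numeric illustration:

    #include <cmath>
    #include <cstdio>

    int main() {
        // The largest fp16 value below 1.0 is 1 - 2^-11 = 0.99951...;
        // tanh(5) = 0.99990... already rounds to 1.0 in fp16, so clamping at +/-5 changes nothing.
        std::printf("tanh(5)   = %.6f\n", std::tanh(5.0));
        std::printf("1 - 2^-11 = %.6f (largest fp16 < 1)\n", 1.0 - std::pow(2.0, -11));
        return 0;
    }
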
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index e75ca3dc90e60..5a33467e7ce7f 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -610,6 +610,25 @@ void
     size_t N
 );
 
+using _mlas_fp16_ = uint16_t;
+
+typedef
+void
+(MLASCALL MLAS_COMPUTE_ERF_FP16_KERNEL)(
+    const _mlas_fp16_* Input,
+    _mlas_fp16_* Output,
+    size_t N
+);
+
+typedef
+void
+(MLASCALL MLAS_COMPUTE_GELU_FP16_KERNEL)(
+    const MLAS_FP16* Input,
+    MLAS_FP16* Output,
+    MLAS_FP16* Temp,
+    int64_t N,
+    const std::string& Algo
+);
+
 typedef
 float
 (MLASCALL MLAS_COMPUTE_SUMEXP_FLOAT_KERNEL)(
@@ -1057,6 +1076,8 @@ extern "C" {
     MLAS_QUANTIZE_LINEAR_U16_KERNEL MlasQuantizeLinearU16Kernel;
     MLAS_QUANTIZE_LINEAR_S4_KERNEL MlasQuantizeLinearS4Kernel;
     MLAS_QUANTIZE_LINEAR_U4_KERNEL MlasQuantizeLinearU4Kernel;
+    MLAS_COMPUTE_ERF_FP16_KERNEL MlasNeonErfF16Kernel;
+    MLAS_COMPUTE_GELU_FP16_KERNEL MlasNeonGeluF16Kernel;
 #if defined(MLAS_TARGET_AMD64)
     MLAS_DEQUANTIZE_LINEAR_S8_KERNEL MlasDequantizeLinearS8Kernel;
     MLAS_DEQUANTIZE_LINEAR_U8_KERNEL MlasDequantizeLinearU8Kernel;
@@ -1411,6 +1432,10 @@ struct MLAS_PLATFORM {
     MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeLogSoftmaxOutputF32Kernel;
     MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeSoftmaxOutputF32Kernel;
 #endif
+#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
+    MLAS_COMPUTE_ERF_FP16_KERNEL* ErfF16KernelRoutine;
+    MLAS_COMPUTE_GELU_FP16_KERNEL* GeluF16KernelRoutine;
+#endif
 #if defined(MLAS_TARGET_AMD64)
     MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1Routine;
     MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1TransposeBRoutine;
diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp
index b913b1c3b8c26..69d1515eef145 100644
--- a/onnxruntime/core/mlas/lib/platform.cpp
+++ b/onnxruntime/core/mlas/lib/platform.cpp
@@ -19,6 +19,10 @@ Module Name:
 #ifdef MLAS_USE_SVE
 #include "sve/mlasi_sve.h"
 #endif
+#if defined(MLAS_NEON_INTRINSICS) && defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
+#include "erf_neon_fp16.h"
+#include "gelu.h"
+#endif
 #if defined(USE_KLEIDIAI)
 #include "kleidiai/mlasi_kleidiai.h"
 #endif
@@ -635,6 +639,17 @@ Return Value:
         this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32Kernel;
         this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32Kernel;
     }
+
+#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
+#if defined(MLAS_USE_SVE)
+    if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) {
+        this->ErfF16KernelRoutine = MlasSveErfF16Kernel;
+        this->GeluF16KernelRoutine = MlasSveGeluF16Kernel;
+    } else
+#endif
+    {
+        this->ErfF16KernelRoutine = MlasNeonErfF16Kernel;
+        this->GeluF16KernelRoutine = MlasNeonGeluF16Kernel;
+    }
+#endif
 #endif
 
 //
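
The SVE kernels in the next file all use the same predicated-loop idiom: build a while-lt predicate over the remaining elements, process svcnth() halfwords per iteration, and let the predicate mask the tail, so no scalar epilogue is needed. A minimal standalone illustration of the idiom; a sketch assuming an SVE-enabled toolchain (compile with -march=armv8.2-a+sve+fp16):

    #include <arm_sve.h>
    #include <cstddef>
    #include <cstdint>

    // Doubles n fp16 values with no scalar tail loop: inactive lanes are masked.
    void double_f16(const __fp16* in, __fp16* out, size_t n) {
        size_t i = 0;
        while (i < n) {
            svbool_t pg = svwhilelt_b16((uint64_t)i, (uint64_t)n);  // lanes with index < n stay active
            svfloat16_t x = svld1_f16(pg, in + i);
            svst1_f16(pg, out + i, svmul_n_f16_x(pg, x, (__fp16)2.0f));
            i += svcnth();  // number of 16-bit lanes per vector
        }
    }
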
diff --git a/onnxruntime/core/mlas/lib/sve/elementwise_sve_fp16.cpp b/onnxruntime/core/mlas/lib/sve/elementwise_sve_fp16.cpp
new file mode 100644
index 0000000000000..573a72dc2117d
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/sve/elementwise_sve_fp16.cpp
@@ -0,0 +1,250 @@
+/*++
+
+Copyright 2025 FUJITSU LIMITED
+
+Module Name:
+
+    elementwise_sve_fp16.cpp
+
+Abstract:
+
+    This module contains the SVE elementwise FP16 functions.
+
+--*/
+#include "mlas_sve_fp16.h"
+
+struct MlasTanhConstants_fp16_scalar {
+    __fp16 LowerRange;
+    __fp16 UpperRange;
+    __fp16 alpha_7;
+    __fp16 alpha_5;
+    __fp16 alpha_3;
+    __fp16 alpha_1;
+    __fp16 beta_6;
+    __fp16 beta_4;
+    __fp16 beta_2;
+    __fp16 beta_0;
+};
+
+constexpr MlasTanhConstants_fp16_scalar TanhConstantsFp16 = {
+    -3.515625f,
+    3.515625f,
+    5.960464477539063e-08f,
+    1.4841556549072266e-05f,
+    0.000637054443359375f,
+    0.004894256591796875f,
+    1.1920928955078125e-06f,
+    0.00011855363845825195f,
+    0.0022678375244140625f,
+    0.004894256591796875f
+};
+
+static inline MLAS_SVFLOAT16
+Tanh_Vector_SVE_fp16(MLAS_SVFLOAT16 x, MLAS_SVBOOL pg)
+{
+    MLAS_SVFLOAT16 g_LowerRange_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.LowerRange);
+    MLAS_SVFLOAT16 g_UpperRange_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.UpperRange);
+    MLAS_SVFLOAT16 g_alpha_7_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.alpha_7);
+    MLAS_SVFLOAT16 g_alpha_5_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.alpha_5);
+    MLAS_SVFLOAT16 g_alpha_3_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.alpha_3);
+    MLAS_SVFLOAT16 g_alpha_1_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.alpha_1);
+    MLAS_SVFLOAT16 g_beta_6_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.beta_6);
+    MLAS_SVFLOAT16 g_beta_4_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.beta_4);
+    MLAS_SVFLOAT16 g_beta_2_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.beta_2);
+    MLAS_SVFLOAT16 g_beta_0_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.beta_0);
+
+    x = MlasSveMinfloat16(pg, x, g_UpperRange_vec);
+    x = MlasSveMaxfloat16(pg, x, g_LowerRange_vec);
+
+    MLAS_SVFLOAT16 x2 = MlasSveMulfloat16(pg, x, x);
+    MLAS_SVFLOAT16 p = MlasSveMLAfloat16(pg, g_alpha_5_vec, g_alpha_7_vec, x2);
+    p = MlasSveMLAfloat16(pg, g_alpha_3_vec, p, x2);
+    p = MlasSveMLAfloat16(pg, g_alpha_1_vec, p, x2);
+    p = MlasSveMulfloat16(pg, p, x);
+
+    MLAS_SVFLOAT16 q = MlasSveMLAfloat16(pg, g_beta_4_vec, g_beta_6_vec, x2);
+    q = MlasSveMLAfloat16(pg, g_beta_2_vec, q, x2);
+    q = MlasSveMLAfloat16(pg, g_beta_0_vec, q, x2);
+
+    MLAS_SVFLOAT16 res = MlasSveDivfloat16(pg, p, q);
+
+    return res;
+}
+
+void
+MlasSveTanhF16Kernel(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N)
+{
+    size_t offset = 0;
+    const auto* input = reinterpret_cast<const _mlas_fp16_*>(Input);
+    auto* output = reinterpret_cast<_mlas_fp16_*>(Output);
+    while (offset < N) {
+        MLAS_SVBOOL pg = MlasSveSelPredictefloat16(offset, N);
+        MLAS_SVFLOAT16 x = MlasSvereinterpretf16_u16(MlasSveLoadUint16(pg, &input[offset]));
+        MLAS_SVFLOAT16 y = Tanh_Vector_SVE_fp16(x, pg);
+        MlasSveStoreUint16(pg, &output[offset], MlasSvereinterpretu16_f16(y));
+        offset += svcnth();
+    }
+}
+
+static inline MLAS_SVFLOAT16
+exp_neg_rational_approx_f16(MLAS_SVBOOL pg, MLAS_SVFLOAT16 x)
+{
+    const __fp16 a0 = 6.0f;
+    MLAS_SVFLOAT16 max_x = MlasSveBroadcastfloat16(a0);
+    x = MlasSveMinfloat16(pg, x, max_x);
+
+    const __fp16 c0 = 1.330f;
+    const __fp16 c1 = -0.390f;
+    const __fp16 c2 = 0.0288f;
+
+    const __fp16 d0 = 1.338f;
+    const __fp16 d1 = 0.848f;
+    const __fp16 d2 = 0.467f;
+
+    MLAS_SVFLOAT16 c0v = MlasSveBroadcastfloat16(c0);
+    MLAS_SVFLOAT16 c1v = MlasSveBroadcastfloat16(c1);
+    MLAS_SVFLOAT16 c2v = MlasSveBroadcastfloat16(c2);
+    MLAS_SVFLOAT16 d0v = MlasSveBroadcastfloat16(d0);
+    MLAS_SVFLOAT16 d1v = MlasSveBroadcastfloat16(d1);
+    MLAS_SVFLOAT16 d2v = MlasSveBroadcastfloat16(d2);
+    MLAS_SVFLOAT16 x2 = MlasSveMulfloat16(pg, x, x);
+
+    MLAS_SVFLOAT16 num = MlasSveMLAfloat16(pg, c0v, c1v, x);
+    num = MlasSveMLAfloat16(pg, num, c2v, x2);
+
+    MLAS_SVFLOAT16 den = MlasSveMLAfloat16(pg, d0v, d1v, x);
+    den = MlasSveMLAfloat16(pg, den, d2v, x2);
+
+    MLAS_SVFLOAT16 recip = MlasSveReciprocalfloat16(den);
+    recip = MlasSveMulfloat16(pg, recip, MlasSveReciprocalStepfloat16(den, recip));
+    recip = MlasSveMulfloat16(pg, recip, MlasSveReciprocalStepfloat16(den, recip));
+
+    MLAS_SVFLOAT16 result = MlasSveMulfloat16(pg, num, recip);
+    return result;
+}
+
+void MLASCALL
+MlasSveErfF16Kernel(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N)
+{
+    const __fp16 p = 0.328f;
+    const __fp16 a1 = 0.2505f;
+    const __fp16 a2 = -0.2881f;
+    const __fp16 a3 = 1.4102f;
+    const __fp16 a4 = -1.423f;
+    const __fp16 a5 = 1.0547f;
+
+    MLAS_SVFLOAT16 vp = MlasSveBroadcastfloat16(p);
+    MLAS_SVFLOAT16 va1 = MlasSveBroadcastfloat16(a1);
+    MLAS_SVFLOAT16 va2 = MlasSveBroadcastfloat16(a2);
+    MLAS_SVFLOAT16 va3 = MlasSveBroadcastfloat16(a3);
+    MLAS_SVFLOAT16 va4 = MlasSveBroadcastfloat16(a4);
+    MLAS_SVFLOAT16 va5 = MlasSveBroadcastfloat16(a5);
+
+    const __fp16 v1 = 1.0f;
+    const __fp16 v2 = -1.0f;
+    const __fp16 v3 = 0.0f;
+    const __fp16 v4 = 4.0f;
+    MLAS_SVFLOAT16 vone = MlasSveBroadcastfloat16(v1);
+    MLAS_SVFLOAT16 vneg_one = MlasSveBroadcastfloat16(v2);
+    MLAS_SVFLOAT16 vzero = MlasSveBroadcastfloat16(v3);
+    MLAS_SVFLOAT16 vth = MlasSveBroadcastfloat16(v4);
+
+    size_t i = 0;
+    while (i < N) {
+        MLAS_SVBOOL pg = MlasSveSelPredictefloat16(i, N);
+        MLAS_SVFLOAT16 x = MlasSvereinterpretf16_u16(MlasSveLoadUint16(pg, &Input[i]));
+        MLAS_SVBOOL neg_mask = MlasSveComparelessthanfloat16(pg, x, vzero);
+        MLAS_SVFLOAT16 sign = MlasSveSelectfloat16(neg_mask, vneg_one, vone);
+        MLAS_SVFLOAT16 absx = MlasSveAbsolutefloat16(vzero, pg, x);
+        MLAS_SVBOOL use_mask = MlasSveComparelessthanfloat16(pg, absx, vth);
+        MLAS_SVFLOAT16 absx_clamped = MlasSveMinfloat16(pg, absx, vth);
+        MLAS_SVFLOAT16 denom = MlasSveMLAfloat16(pg, vone, vp, absx_clamped);
+        MLAS_SVFLOAT16 t = MlasSveReciprocalfloat16(denom);
+        t = MlasSveMulfloat16(pg, t, MlasSveReciprocalStepfloat16(denom, t));
+        t = MlasSveMulfloat16(pg, t, MlasSveReciprocalStepfloat16(denom, t));
+        MLAS_SVFLOAT16 t2 = MlasSveMulfloat16(pg, t, t);
+        MLAS_SVFLOAT16 t3 = MlasSveMulfloat16(pg, t2, t);
+        MLAS_SVFLOAT16 t4 = MlasSveMulfloat16(pg, t3, t);
+        MLAS_SVFLOAT16 t5 = MlasSveMulfloat16(pg, t4, t);
+        MLAS_SVFLOAT16 poly = MlasSveMulfloat16(pg, va1, t);
+        poly = MlasSveMLAfloat16(pg, poly, va2, t2);
+        poly = MlasSveMLAfloat16(pg, poly, va3, t3);
+        poly = MlasSveMLAfloat16(pg, poly, va4, t4);
+        poly = MlasSveMLAfloat16(pg, poly, va5, t5);
+        MLAS_SVFLOAT16 x2 = MlasSveMulfloat16(pg, absx_clamped, absx_clamped);
+        MLAS_SVFLOAT16 exp_neg_x2 = exp_neg_rational_approx_f16(pg, x2);
+        MLAS_SVFLOAT16 poly_mul_exp = MlasSveMulfloat16(pg, poly, exp_neg_x2);
+        MLAS_SVFLOAT16 one_minus_term = MlasSveSubtractfloat16(pg, vone, poly_mul_exp);
+        MLAS_SVFLOAT16 erf_approx = MlasSveMulfloat16(pg, sign, one_minus_term);
+        erf_approx = MlasSveMinfloat16(pg, erf_approx, vone);
+        erf_approx = MlasSveMaxfloat16(pg, erf_approx, vneg_one);
+        MLAS_SVFLOAT16 result = MlasSveSelectfloat16(use_mask, erf_approx, sign);
+        MlasSveStoreUint16(pg, &Output[i], MlasSvereinterpretu16_f16(result));
+        i += svcntp_b16(svptrue_b16(), pg);
+    }
+}
+
+void MLASCALL
+MlasSveGeluF16Kernel(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp, int64_t count, const std::string& algo)
+{
+    const __fp16 r1 = 0.5f;
+    const __fp16 r2 = 1.0f;
+    const __fp16 r3 = static_cast<__fp16>(M_SQRT1_2);
+    const __fp16 r4 = 0.7978845608028654f;
+    const __fp16 r5 = 0.035677408136300125f;
+
+    const MLAS_SVFLOAT16 v_half = MlasSveBroadcastfloat16(r1);
+    const MLAS_SVFLOAT16 v_one = MlasSveBroadcastfloat16(r2);
+    const MLAS_SVFLOAT16 v_sqrt1_2 = MlasSveBroadcastfloat16(r3);
+    const MLAS_SVFLOAT16 v_B = MlasSveBroadcastfloat16(r4);
+    const MLAS_SVFLOAT16 v_C = MlasSveBroadcastfloat16(r5);
+
+    const __fp16 c1 = -5.0f;
+    const __fp16 c2 = 5.0f;
+    if (algo == "tanh") {
+        int64_t i = 0;
+        while (i < count) {
+            MLAS_SVBOOL pg = MlasSveSelPredictefloat16(i, count);
+            MLAS_SVFLOAT16 v_x = MlasSveLoadFloat16(pg, &input[i]);
+            MLAS_SVFLOAT16 v_x2 = MlasSveMulfloat16(pg, v_x, v_x);
+            MLAS_SVFLOAT16 v_inner = MlasSveMLAfloat16(pg, v_B, v_C, v_x2);
+            MLAS_SVFLOAT16 v_tanh_arg = MlasSveMulfloat16(pg, v_x, v_inner);
+            v_tanh_arg = MlasSveMaxfloat16(pg, MlasSveBroadcastfloat16(c1), MlasSveMinfloat16(pg, v_tanh_arg, MlasSveBroadcastfloat16(c2)));
+            MlasSveStoreF16(pg, &temp[i], v_tanh_arg);
+            i += svcnth();
+        }
+
+        MlasSveTanhF16Kernel(temp, temp, count);
+
+        int64_t j = 0;
+        while (j < count) {
+            MLAS_SVBOOL pg = MlasSveSelPredictefloat16(j, count);
+            MLAS_SVFLOAT16 v_x = MlasSveLoadFloat16(pg, &input[j]);
+            MLAS_SVFLOAT16 v_tanh = MlasSveLoadFloat16(pg, &temp[j]);
+            MLAS_SVFLOAT16 v_result = MlasSveMulfloat16(pg, v_half, MlasSveMulfloat16(pg, v_x, MlasSveAddfloat16(pg, v_one, v_tanh)));
+            MlasSveStoreF16(pg, &output[j], v_result);
+            j += svcnth();
+        }
+    } else {
+        int64_t i = 0;
+        while (i < count) {
+            MLAS_SVBOOL pg = MlasSveSelPredictefloat16(i, count);
+            MLAS_SVFLOAT16 v_x = MlasSveLoadFloat16(pg, &input[i]);
+            MLAS_SVFLOAT16 v_scaled = MlasSveMulfloat16(pg, v_x, v_sqrt1_2);
+            MlasSveStoreF16(pg, &temp[i], v_scaled);
+            i += svcnth();
+        }
+
+        MlasSveErfF16Kernel(reinterpret_cast<const _mlas_fp16_*>(temp), reinterpret_cast<_mlas_fp16_*>(temp), count);
+
+        int64_t j = 0;
+        while (j < count) {
+            MLAS_SVBOOL pg = MlasSveSelPredictefloat16(j, count);
+            MLAS_SVFLOAT16 v_x = MlasSveLoadFloat16(pg, &input[j]);
+            MLAS_SVFLOAT16 v_erf = MlasSveLoadFloat16(pg, &temp[j]);
+            MLAS_SVFLOAT16 v_result = MlasSveMulfloat16(pg, v_half, MlasSveMulfloat16(pg, v_x, MlasSveAddfloat16(pg, v_one, v_erf)));
+            MlasSveStoreF16(pg, &output[j], v_result);
+            j += svcnth();
+        }
+    }
+}
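
A recurring pattern in both the NEON and SVE erf kernels is the division-free reciprocal: svrecpe_f16 (vrecpeq_f16 on NEON) produces a low-precision estimate and each svrecps_f16 Newton-Raphson step roughly doubles its accuracy, so two steps are ample for fp16's 11-bit significand. Isolated, the pattern looks like this; a sketch under the same SVE toolchain assumptions as above:

    #include <arm_sve.h>

    // Reciprocal of d via estimate plus two Newton-Raphson refinements,
    // mirroring the recip sequence in the erf kernels above.
    svfloat16_t reciprocal_f16(svbool_t pg, svfloat16_t d) {
        svfloat16_t r = svrecpe_f16(d);             // initial estimate
        r = svmul_f16_x(pg, r, svrecps_f16(d, r));  // r *= (2 - d*r)
        r = svmul_f16_x(pg, r, svrecps_f16(d, r));  // second refinement
        return r;
    }
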
diff --git a/onnxruntime/core/mlas/lib/sve/mlas_sve_fp16.h b/onnxruntime/core/mlas/lib/sve/mlas_sve_fp16.h
new file mode 100644
index 0000000000000..45379b8da1476
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/sve/mlas_sve_fp16.h
@@ -0,0 +1,182 @@
+/*++
+
+Copyright 2025 FUJITSU LIMITED
+
+Module Name:
+
+    mlas_sve_fp16.h
+
+Abstract:
+
+    This module contains the procedure prototypes for the SVE FP16 intrinsics.
+
+--*/
+
+#pragma once
+#include <arm_sve.h>
+#include <cmath>  // for isnan if needed
+#include <cstdint>
+
+#include "mlasi_sve.h"
+
+#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT16
+MlasSveBroadcastfloat16(__fp16 Value)
+{
+    return svdup_f16(Value);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT16
+MlasSveMinfloat16(MLAS_SVBOOL pg, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 range)
+{
+    return svmin_f16_m(pg, x, range);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT16
+MlasSveMaxfloat16(MLAS_SVBOOL pg, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 range)
+{
+    return svmax_f16_m(pg, x, range);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT16
+MlasSveMulfloat16(MLAS_SVBOOL pg, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 y)
+{
+    return svmul_f16_m(pg, x, y);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT16
+MlasSveMLAfloat16(MLAS_SVBOOL pg, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 y, MLAS_SVFLOAT16 z)
+{
+    return svmla_f16_m(pg, x, y, z);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT16
+MlasSveDivfloat16(MLAS_SVBOOL pg, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 y)
+{
+    return svdiv_f16_m(pg, x, y);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVBOOL
+MlasSveSelPredictefloat16(size_t x, size_t y)
+{
+    return svwhilelt_b16(x, y);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT16
+MlasSvereinterpretf16_u16(MLAS_SVUINT16 x)
+{
+    return svreinterpret_f16_u16(x);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVUINT16
+MlasSveLoadUint16(MLAS_SVBOOL pg, const uint16_t* x)
+{
+    return svld1_u16(pg, x);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT16
+MlasSveLoadFloat16(MLAS_SVBOOL pg, const MLAS_FP16* x)
+{
+    return svld1_f16(pg, reinterpret_cast<const __fp16*>(x));
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+void
+MlasSveStoreUint16(MLAS_SVBOOL pg, uint16_t* Buffer, MLAS_SVUINT16 Vector)
+{
+    return svst1_u16(pg, Buffer, Vector);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+void
+MlasSveStoreF16(MLAS_SVBOOL pg, MLAS_FP16* Buffer, MLAS_SVFLOAT16 Vector)
+{
+    return svst1_f16(pg, reinterpret_cast<__fp16*>(Buffer), Vector);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVUINT16
+MlasSvereinterpretu16_f16(MLAS_SVFLOAT16 x)
+{
+    return svreinterpret_u16_f16(x);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT16
+MlasSveReciprocalfloat16(MLAS_SVFLOAT16 x)
+{
+    return svrecpe_f16(x);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT16
+MlasSveReciprocalStepfloat16(MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 y)
+{
+    return svrecps_f16(x, y);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT16
+MlasSveSelectfloat16(MLAS_SVBOOL Pred, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 y)
+{
+    return svsel_f16(Pred, x, y);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT16
+MlasSveSubtractfloat16(MLAS_SVBOOL Pred, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 y)
+{
+    return svsub_f16_m(Pred, x, y);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVBOOL
+MlasSveComparelessthanfloat16(MLAS_SVBOOL Pred, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 y)
+{
+    return svcmplt_f16(Pred, x, y);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT16
+MlasSveAbsolutefloat16(MLAS_SVFLOAT16 inactive, MLAS_SVBOOL Pred, MLAS_SVFLOAT16 y)
+{
+    return svabs_f16_m(inactive, Pred, y);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT16
+MlasSveAddfloat16(MLAS_SVBOOL Pred, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 y)
+{
+    return svadd_f16_m(Pred, x, y);
+}
+#endif
diff --git a/onnxruntime/core/mlas/lib/sve/mlasi_sve.h b/onnxruntime/core/mlas/lib/sve/mlasi_sve.h
index 67a4bf453dd05..010b18b4be9ab 100644
--- a/onnxruntime/core/mlas/lib/sve/mlasi_sve.h
+++ b/onnxruntime/core/mlas/lib/sve/mlasi_sve.h
@@ -32,8 +32,37 @@ typedef svfloat32_t MLAS_SVFLOAT32;
 typedef svint32_t MLAS_SVINT32;
 typedef svuint32_t MLAS_SVUINT32;
 typedef svbool_t MLAS_SVBOOL;
+typedef svfloat16_t MLAS_SVFLOAT16;
+typedef svuint16_t MLAS_SVUINT16;
 
-// function decarations
+using _mlas_fp16_ = uint16_t;
+
+void
+MLASCALL
+MlasSveErfF16Kernel(
+    const _mlas_fp16_* Input,
+    _mlas_fp16_* Output,
+    size_t N
+);
+
+void
+MLASCALL
+MlasSveTanhF16Kernel(
+    const MLAS_FP16* Input,
+    MLAS_FP16* Output,
+    size_t N
+);
+
+void
+MLASCALL
+MlasSveGeluF16Kernel(
+    const MLAS_FP16* Input,
+    MLAS_FP16* Output,
+    MLAS_FP16* Temp,
+    int64_t N,
+    const std::string& Algo
+);
+
+// function declarations
 MLAS_FORCEINLINE
 MLAS_SVFLOAT32
 MlasSveComputeExpVector(
diff --git a/onnxruntime/core/mlas/lib/tanh.cpp b/onnxruntime/core/mlas/lib/tanh.cpp
index 63bd744795535..7fb0336e527ae 100644
--- a/onnxruntime/core/mlas/lib/tanh.cpp
+++ b/onnxruntime/core/mlas/lib/tanh.cpp
@@ -22,7 +22,9 @@ Module Name:
 
 #include "mlasi.h"
 #include "softmax.h"
-
+#ifdef MLAS_USE_SVE
+#include "sve/mlasi_sve.h"
+#endif
 //
 // Bundles the floating point constants for use by kernels written in assembly.
 //
@@ -193,6 +195,13 @@ MlasComputeTanh(
     MLAS_FP16* Output,
     size_t N
 ) {
+#if defined(MLAS_USE_SVE)
+    if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) {
+        MlasSveTanhF16Kernel(Input, Output, N);
+        return;
+    }
+#endif
+
     const auto* dispatch = GetMlasPlatform().SoftmaxDispatch;
     if (dispatch == nullptr || dispatch->Tanh_Fp16 == nullptr) {
         MLAS_THROW_EX(std::runtime_error, "Tanh_Fp16 is not supported.");
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
index 0ad8d1d4fef4d..72f5ddc7fd19a 100644
--- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
+++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -1228,7 +1228,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, BFloat16, IsNaN);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Gelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, Gelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu);
 #if !defined(DISABLE_FLOAT8_TYPES)
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FNUZ, IsNaN);
@@ -3269,7 +3270,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
                                                                 IsNaN)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, BFloat16,
                                                                 IsNaN)>,
-    BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Gelu)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float,
+                                                                Gelu)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16,
+                                                                Gelu)>,
 #if !defined(DISABLE_FLOAT8_TYPES)
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN,
                                                                 IsNaN)>,
diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc
index b940d71e1165e..f4bf60bd9b6a8 100644
--- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc
+++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc
@@ -2009,7 +2009,7 @@ Status Erf<float>::Compute(OpKernelContext* context) const {
         const float* p_input = input_data + start;
         float* p_output = output_data + start;
         const std::ptrdiff_t count = std::min(length_per_task, elem_count - start);
-        MlasComputeErf(p_input, p_output, count);
+        MlasComputeErf(p_input, p_output, static_cast<size_t>(count));
       },
       0);
 
@@ -2027,7 +2027,6 @@ Status Erf<MLFloat16>::Compute(OpKernelContext* context) const {
   int64_t elem_count = X->Shape().Size();
   constexpr int64_t length_per_task = 4096;
   int64_t task_count = (elem_count + length_per_task - 1) / length_per_task;
-
   const auto narrow_task_count = onnxruntime::narrow<ptrdiff_t>(task_count);
 
   // get allocator for temporary buffers
@@ -2039,19 +2038,9 @@ Status Erf<MLFloat16>::Compute(OpKernelContext* context) const {
       [&](ptrdiff_t task_idx) {
         const auto start = task_idx * length_per_task;
         const int64_t count = std::min(length_per_task, elem_count - start);
-        const auto narrow_count = onnxruntime::narrow<size_t>(count);
-
         const MLFloat16* p_input = input_data + start;
         MLFloat16* p_output = output_data + start;
-
-        // allocate temp buffers using ORT allocator
-        IAllocatorUniquePtr<float> input_fp32 = IAllocator::MakeUniquePtr<float>(alloc, narrow_count);
-        IAllocatorUniquePtr<float> output_fp32 = IAllocator::MakeUniquePtr<float>(alloc, narrow_count);
-
-        // convert, compute, convert back
-        MlasConvertHalfToFloatBuffer(p_input, input_fp32.get(), narrow_count);
-        MlasComputeErf(input_fp32.get(), output_fp32.get(), narrow_count);
-        MlasConvertFloatToHalfBuffer(output_fp32.get(), p_output, narrow_count);
+        MlasComputeFP16Erf(p_input, p_output, static_cast<size_t>(count));
       },
       0);
 
diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.cc b/onnxruntime/core/providers/cpu/tensor/gelu.cc
index d55973eda180f..abf88392c2315 100644
--- a/onnxruntime/core/providers/cpu/tensor/gelu.cc
+++ b/onnxruntime/core/providers/cpu/tensor/gelu.cc
@@ -12,6 +12,31 @@
 #include "core/providers/cpu/element_wise_ranged_transform.h"
 #include "core/providers/cpu/tensor/gelu.h"
 
+#include <algorithm>
+#include <cstdlib>
+#include <memory>
+
+#if defined(_WIN32)
+#include <malloc.h>
+#endif
+
+inline void* AlignedAlloc(size_t alignment, size_t size) {
+#if defined(_WIN32)
+  return _aligned_malloc(size, alignment);
+#else
+  // std::aligned_alloc requires size to be a multiple of alignment
+  return std::aligned_alloc(alignment, size);
+#endif
+}
+
+inline void AlignedFree(void* p) {
+#if defined(_WIN32)
+  _aligned_free(p);
+#else
+  std::free(p);
+#endif
+}
+
 using onnxruntime::narrow;
 using namespace onnxruntime::common;
 
@@ -19,11 +44,17 @@ namespace onnxruntime {
 
 // May revisit the implementations to support inplace computation, if needed.
 
-ONNX_CPU_OPERATOR_KERNEL(
-    Gelu,
-    20,
-    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
-    Gelu<float>);
+#define ADD_TYPED_GELU_OP(data_type)                                          \
+  ONNX_CPU_OPERATOR_TYPED_KERNEL(                                             \
+      Gelu,                                                                   \
+      20,                                                                     \
+      data_type,                                                              \
+      KernelDefBuilder()                                                      \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()),     \
+      Gelu<data_type>)
+
+ADD_TYPED_GELU_OP(float);
+ADD_TYPED_GELU_OP(MLFloat16);
 
 #ifndef DISABLE_CONTRIB_OPS
 namespace contrib {
@@ -46,9 +77,9 @@ Status Gelu<T>::Compute(OpKernelContext* context) const {
   T* output_data = output->MutableData<T>();
   concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
 
-  int64_t elem_count = input->Shape().Size();
-  constexpr int64_t length_per_task = 4096;  // this number comes from FastGelu.
-  int64_t task_count = (elem_count + length_per_task - 1) / length_per_task;
+  size_t elem_count = input->Shape().Size();
+  constexpr size_t length_per_task = 4096;  // this number comes from FastGelu.
+  size_t task_count = (elem_count + length_per_task - 1) / length_per_task;
 
   if (approximation_algorithm_ == "tanh") {
     // FastGelu allows optional bias. Here we split input data into chunks. Each chunk
@@ -64,16 +95,16 @@ Status Gelu<T>::Compute(OpKernelContext* context) const {
           const auto start = task_idx * length_per_task;
           const T* p_input = input_data + start;
           T* p_output = output_data + start;
-          int64_t count = std::min(length_per_task, elem_count - start);
+          size_t count = std::min(length_per_task, elem_count - start);
 
-          for (int64_t i = 0; i < count; i++) {
+          for (size_t i = 0; i < count; i++) {
             T value = p_input[i];
             p_output[i] = value * (static_cast<T>(C) * value * value + static_cast<T>(B));
           }
 
           MlasComputeTanh(p_output, p_output, narrow<size_t>(count));
 
-          for (int64_t i = 0; i < count; i++) {
+          for (size_t i = 0; i < count; i++) {
             p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f);
           }
         },
@@ -86,16 +117,16 @@ Status Gelu<T>::Compute(OpKernelContext* context) const {
           const auto start = task_idx * length_per_task;
           const T* p_input = input_data + start;
           T* p_output = output_data + start;
-          int64_t count = std::min(length_per_task, elem_count - start);
+          size_t count = std::min(length_per_task, elem_count - start);
 
-          for (int64_t i = 0; i < count; i++) {
+          for (size_t i = 0; i < count; i++) {
             T value = p_input[i];
             p_output[i] = value * static_cast<T>(M_SQRT1_2);
           }
 
           MlasComputeErf(p_output, p_output, narrow<size_t>(count));
 
-          for (int64_t i = 0; i < count; i++) {
+          for (size_t i = 0; i < count; i++) {
             p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f);
           }
         },
@@ -105,4 +136,54 @@ Status Gelu<T>::Compute(OpKernelContext* context) const {
   return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported approximation_algorithm: ", approximation_algorithm_);
 }
 
+template <>
+Status Gelu<MLFloat16>::Compute(OpKernelContext* context) const {
+  const Tensor* input = context->Input<Tensor>(0);
+  const MLFloat16* input_data = input->Data<MLFloat16>();
+  Tensor* output = context->Output(0, input->Shape());
+  MLFloat16* output_data = output->MutableData<MLFloat16>();
+  concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
+  size_t elem_count = input->Shape().Size();
+  constexpr size_t length_per_task = 4096;
+  size_t task_count = (elem_count + length_per_task - 1) / length_per_task;
+
+  if (approximation_algorithm_ != "tanh" && approximation_algorithm_ != "none") {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported approximation_algorithm: ", approximation_algorithm_);
+  }
+
+  // Alignment and buffer size for aligned_alloc
+  constexpr size_t alignment = 64;
+
+  size_t buffer_size = elem_count * sizeof(MLFloat16);
+  size_t aligned_size =
+      ((buffer_size + alignment - 1) / alignment) * alignment;
+
+  void* raw = AlignedAlloc(alignment, aligned_size);
+  if (!raw) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
+                           "Failed to allocate aligned temporary buffer.");
+  }
+
+  auto deleter = [](MLFloat16* p) {
+    AlignedFree(p);
+  };
+
+  std::unique_ptr<MLFloat16, decltype(deleter)> temp_fp16_aligned(
+      static_cast<MLFloat16*>(raw), deleter);
+
+  concurrency::ThreadPool::TryBatchParallelFor(
+      tp,
+      static_cast<std::ptrdiff_t>(task_count),
+      [&](ptrdiff_t task_idx) {
+        const auto start = task_idx * length_per_task;
+        const MLFloat16* p_input = input_data + start;
+        MLFloat16* p_output = output_data + start;
+        size_t count = std::min(length_per_task, elem_count - start);
+        MLFloat16* p_temp = temp_fp16_aligned.get() + start;
+        MlasComputeFP16Gelu(p_input, p_output, p_temp, count, approximation_algorithm_);
+      },
+      0);
+  return Status::OK();
+}
+
 }  // namespace onnxruntime
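
On the AlignedAlloc helper added to gelu.cc: the buffer size is rounded up to a multiple of the alignment before calling std::aligned_alloc because the standard requires it (glibc happens to tolerate other sizes; other runtimes may return nullptr), and Windows needs _aligned_malloc/_aligned_free since it lacks std::aligned_alloc. A small standalone demonstration for POSIX/C11-style runtimes:

    #include <cstdio>
    #include <cstdlib>

    int main() {
        constexpr size_t alignment = 64;
        size_t wanted = 1000;  // not a multiple of 64
        size_t rounded = ((wanted + alignment - 1) / alignment) * alignment;  // 1024
        void* p = std::aligned_alloc(alignment, rounded);  // passing 1000 here would be undefined behavior
        std::printf("aligned_alloc(%zu, %zu) -> %p\n", alignment, rounded, p);
        std::free(p);
        return 0;
    }
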
diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc
index d711e050fb913..b0a4ce3ed6599 100644
--- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc
+++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc
@@ -752,6 +752,58 @@ TEST_F(ActivationOpTest, ONNX_Gelu) {
       {}, {{"approximate", "tanh"}}, true, 20);
 }
+
+#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
+TEST_F(ActivationOpTest, Gelu_fp16_tanh) {
+  OpTester test("Gelu", 20);
+  auto formula = [](float x) {
+    return 0.5f * x * (1 + tanhf(0.7978845608028654f * (x + 0.044715f * x * x * x)));
+  };
+  const std::vector<float> X = {-1.0f, 0, 1.0f, 100.0f, -100.0f, 1000.0f, -1000.0f};
+  std::vector<float> Y;
+  Y.reserve(X.size());
+  for (float x : X) {
+    Y.push_back(formula(x));
+  }
+  std::vector<int64_t> dims{static_cast<int64_t>(X.size())};
+
+  std::vector<MLFloat16> f_X(X.size());
+  std::vector<MLFloat16> f_Y(Y.size());
+  ConvertFloatToMLFloat16(X.data(), f_X.data(), static_cast<int>(X.size()));
+  ConvertFloatToMLFloat16(Y.data(), f_Y.data(), static_cast<int>(Y.size()));
+
+  test.AddInput<MLFloat16>("X", dims, f_X);
+  test.AddOutput<MLFloat16>("Y", dims, f_Y);
+  test.AddAttribute("approximate", "tanh");
+
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
+}
+
+TEST_F(ActivationOpTest, Gelu_fp16_erf) {
+  OpTester test("Gelu", 20);
+  auto formula = [](float x) {
+    return static_cast<float>(0.5 * x * (1 + erf(x * M_SQRT1_2)));
+  };
+  const std::vector<float> X = {-1.0f, 0, 1.0f, 100.0f, -100.0f, 1000.0f, -1000.0f};
+  std::vector<float> Y;
+  Y.reserve(X.size());
+  for (float x : X) {
+    Y.push_back(formula(x));
+  }
+  std::vector<int64_t> dims{static_cast<int64_t>(X.size())};
+
+  std::vector<MLFloat16> f_X(X.size());
+  std::vector<MLFloat16> f_Y(Y.size());
+  ConvertFloatToMLFloat16(X.data(), f_X.data(), static_cast<int>(X.size()));
+  ConvertFloatToMLFloat16(Y.data(), f_Y.data(), static_cast<int>(Y.size()));
+
+  test.AddInput<MLFloat16>("X", dims, f_X);
+  test.AddOutput<MLFloat16>("Y", dims, f_Y);
+  test.AddAttribute("approximate", "none");
+
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
+}
+#endif
 #endif
 
 }  // namespace test
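
A note on the test vectors: the large magnitudes (±100, ±1000) sit in GELU's saturating regime, where gelu(x) ≈ x for large positive x and ≈ 0 for large negative x, so the expected fp16 outputs are exact even though the kernels are approximate; the small inputs are the interesting accuracy cases. A throwaway sketch of the reference values the tests encode (using a literal for 1/sqrt(2) to sidestep M_SQRT1_2 portability):

    #include <cmath>
    #include <cstdio>

    int main() {
        const float inv_sqrt2 = 0.70710678f;  // 1/sqrt(2)
        for (float x : {-1.0f, 0.0f, 1.0f, 100.0f, -100.0f, 1000.0f, -1000.0f}) {
            float exact = 0.5f * x * (1.0f + std::erf(x * inv_sqrt2));
            float tanh_form = 0.5f * x * (1.0f + std::tanh(0.7978845608f * (x + 0.044715f * x * x * x)));
            std::printf("x=%8.1f  exact=%12.6g  tanh=%12.6g\n", x, exact, tanh_form);
        }
        return 0;
    }
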