Skip to content

Commit 040fe49

Browse files
committed
Address comments from reviewers
Signed-off-by: Milos Puzovic <milos.puzovic@arm.com>
1 parent c54266b commit 040fe49

File tree

6 files changed: 84 additions and 16 deletions.

cmake/onnxruntime_mlas.cmake

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -327,13 +327,17 @@ function (setup_arm_neon_nchwc)
327327
${MLAS_SRC_DIR}/sconv_nchwc_kernel_neon.h
328328
${MLAS_SRC_DIR}/sconv_nchwc_kernel_neon.cpp
329329
${MLAS_SRC_DIR}/spool_nchwc_kernel_neon.cpp
330-
# Hand written AArch64 micro-kernel for NCHW convolution. Using a
331-
# separate assembly file allows tighter control over register allocation
332-
# and avoids the overhead of C++/intrinsics based code generation.
333-
${MLAS_SRC_DIR}/aarch64/SconvKernelNeon.S
334-
${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeon.S
335-
${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeon.S
336330
)
331+
if(NOT WIN32)
332+
target_sources(onnxruntime_mlas PRIVATE
333+
# Hand written AArch64 micro-kernel for NCHW convolution. Using a
334+
# separate assembly file allows tighter control over register allocation
335+
# and avoids the overhead of C++/intrinsics based code generation.
336+
${MLAS_SRC_DIR}/aarch64/SconvKernelNeon.S
337+
${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeon.S
338+
${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeon.S
339+
)
340+
endif()
337341
list(APPEND mlas_private_compile_definitions MLAS_USE_ARM_NEON_NCHWC)
338342
set(mlas_private_compile_definitions ${mlas_private_compile_definitions} PARENT_SCOPE)
339343
endfunction ()
@@ -466,8 +470,6 @@ else()
466470
${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSdot.S
467471
${MLAS_SRC_DIR}/aarch64/SgemmKernelNeon.S
468472
${MLAS_SRC_DIR}/aarch64/SgemvKernelNeon.S
469-
${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeon.S
470-
${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeon.S
471473
${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S
472474
${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdot.S
473475
${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdotLd64.S

onnxruntime/core/mlas/lib/aarch64/SconvDepthwiseKernelNeon.S

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@ SPDX-License-Identifier: MIT
44

55
Module Name:
66

7-
SconvDepthwiseFloatKernelNeon.S
7+
SconvDepthwiseKernelNeon.S
88

99
Abstract:
1010

@@ -18,7 +18,6 @@ Abstract:
1818
* When an output position touches padding, only the affected 4-wide
1919
lanes are checked individually and loaded; others are zeroed. This
2020
mirrors the behavior of the C++ helper LoadInputVectorWithBounds.
21-
mirrors the behaviour of the C++ helper LoadInputVectorWithBounds.
2221
* Keep the multiply/accumulate operations tightly scheduled to hide the
2322
load latency.
2423

onnxruntime/core/mlas/lib/mlasi.h

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -963,15 +963,25 @@ extern "C" {
963963
MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelAdd;
964964
#endif
965965
#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)
966+
// Intrinsics kernel for direct NCHW convolution
967+
MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelNeon;
968+
#if !defined(_WIN32)
966969
// AArch64 assembly micro-kernel for direct NCHW convolution
967970
MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelNeonAsm;
971+
#endif
968972
MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelNeon;
973+
// Intrinsics kernel for depthwise NCHWc convolution
974+
MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelNeon;
975+
#if !defined(_WIN32)
969976
// AArch64 assembly micro-kernel for depthwise NCHWc convolution
970977
MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelNeonAsm;
971-
MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelNeon;
978+
#endif
979+
// Intrinsics kernel for pointwise NCHWc convolution
980+
MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelNeon;
981+
#if !defined(_WIN32)
972982
// AArch64 assembly micro-kernel for pointwise NCHWc convolution
973983
MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelNeonAsm;
974-
MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelNeon;
984+
#endif
975985
#if defined(__aarch64__) && defined(__linux__)
976986
MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseBf16KernelNeon;
977987
#endif

onnxruntime/core/mlas/lib/platform.cpp

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -571,10 +571,15 @@ Return Value:
571571
this->EltwiseDispatch = &MlasEltwiseDispatchNeon;
572572

573573
#if defined(MLAS_USE_ARM_NEON_NCHWC)
574+
// Use the AArch64 assembly implementation on non-Windows platforms.
575+
#if !defined(_WIN32)
574576
// Prefer the hand written micro-kernel for the NCHW convolution path. It
575577
// offers a tighter schedule and a specialised two-output inner loop that
576578
// reduces pressure on the memory system compared to the generic kernel.
577579
this->ConvNchwFloatKernel = MlasConvNchwFloatKernelNeonAsm;
580+
#else
581+
this->ConvNchwFloatKernel = MlasConvNchwFloatKernelNeon;
582+
#endif
578583
// Prefer the hand written AArch64 micro-kernel for pointwise convolution
579584
// as it computes multiple output positions at once and significantly
580585
// reduces memory traffic. The AArch64 assembly kernel is selected by

onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp

Lines changed: 52 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -183,6 +183,58 @@ void
183183
}
184184

185185

186+
//
187+
// Implementation of MlasConvNchwFloatKernelNeon
//
// Thin entry point: forwards every argument, unchanged and in the same
// order, to MlasConvFloatKernelNeonImpl<false> and returns nothing itself.
// NOTE(review): the <false> template argument presumably selects the plain
// NCHW input path (as opposed to the NCHWc kernel that follows) -- confirm
// against the template definition, which is not visible in this hunk.
188+
//
189+
190+
void
191+
MLASCALL
192+
MlasConvNchwFloatKernelNeon(
193+
const float* Input,
194+
const float* Filter,
195+
float* Output,
196+
size_t StrideWidth,
197+
size_t DilationWidth,
198+
size_t FilterCount,
199+
size_t InputStride,
200+
size_t FilterStride,
201+
size_t OutputStride,
202+
size_t KernelHeight,
203+
size_t KernelWidth,
204+
const float* InputBase,
205+
size_t InputWidth,
206+
size_t DilatedInputWidth,
207+
size_t OutputCountLeftPad,
208+
size_t OutputCount,
209+
size_t OutputCountRightPad,
210+
const float* Bias,
211+
unsigned KernelFlags
212+
)
213+
{
// Delegate to the shared templated implementation; no work is done here.
214+
MlasConvFloatKernelNeonImpl<false>(
215+
Input,
216+
Filter,
217+
Output,
218+
StrideWidth,
219+
DilationWidth,
220+
FilterCount,
221+
InputStride,
222+
FilterStride,
223+
OutputStride,
224+
KernelHeight,
225+
KernelWidth,
226+
InputBase,
227+
InputWidth,
228+
DilatedInputWidth,
229+
OutputCountLeftPad,
230+
OutputCount,
231+
OutputCountRightPad,
232+
Bias,
233+
KernelFlags
234+
);
235+
}
236+
237+
186238
//
187239
// Implementation of MlasConvNchwcFloatKernelNeon
188240
//

onnxruntime/core/mlas/lib/snchwc.cpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -882,7 +882,7 @@ struct MLAS_NCHWC_CONV_POINTWISE_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM
882882

883883
#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC))
884884
MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvPointwiseFloatKernel;
885-
#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)
885+
#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC) && !defined(_WIN32)
886886
MLAS_CONV_POINTWISE_FLOAT_KERNEL* const KernelFast = MlasConvPointwiseFloatKernelNeonAsm;
887887
#endif
888888
#if defined(__aarch64__) && defined(__linux__)
@@ -940,7 +940,7 @@ struct MLAS_NCHWC_CONV_POINTWISE_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM
940940
//
941941

942942
MLAS_CONV_POINTWISE_FLOAT_KERNEL* KernelToUse = Kernel;
943-
#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)
943+
#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC) && !defined(_WIN32)
944944
if (!WorkBlock->UseBf16 && OutputThisIteration >= 4 &&
945945
StrideHeight == 1 && StrideWidth == 1) {
946946
KernelToUse = KernelFast;
@@ -1034,7 +1034,7 @@ struct MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM : MLAS_NCHWC_CONV_ALGORITHM
10341034

10351035
#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC))
10361036
MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvDepthwiseFloatKernel;
1037-
#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)
1037+
#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC) && !defined(_WIN32)
10381038
MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* const KernelFast = MlasConvDepthwiseFloatKernelNeonAsm;
10391039
#endif
10401040
#else
@@ -1061,7 +1061,7 @@ struct MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM : MLAS_NCHWC_CONV_ALGORITHM
10611061
//
10621062

10631063
MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* KernelToUse = Kernel;
1064-
#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)
1064+
#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC) && !defined(_WIN32)
10651065
if (OutputWidth >= 4) {
10661066
KernelToUse = KernelFast;
10671067
}

0 commit comments

Comments
 (0)