Skip to content

Commit bd38b0e

Browse files
committed
Move comment to more appropriate place
1 parent 71fa09f commit bd38b0e

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

onnxruntime/core/mlas/lib/platform.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -580,11 +580,6 @@ Return Value:
580580
#else
581581
this->ConvNchwFloatKernel = MlasConvNchwFloatKernelNeon;
582582
#endif
583-
// Prefer the hand written AArch64 micro-kernel for pointwise convolution
584-
// as it computes multiple output positions at once and significantly
585-
// reduces memory traffic. The AArch64 assembly kernel is selected by
586-
// heuristics in snchwc.cpp to avoid regressions on small convolutions, so
587-
// we set the default to the intrinsics version here.
588583
this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelNeon;
589584
this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeon;
590585
this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelNeon;

onnxruntime/core/mlas/lib/snchwc.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -883,6 +883,8 @@ struct MLAS_NCHWC_CONV_POINTWISE_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM
883883
#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC))
884884
MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvPointwiseFloatKernel;
885885
#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC) && !defined(_WIN32)
886+
// AArch64 assembly kernel pointwise convolution computes multiple
887+
// output positions at once and significantly reduces memory traffic.
886888
MLAS_CONV_POINTWISE_FLOAT_KERNEL* const KernelFast = MlasConvPointwiseFloatKernelNeonAsm;
887889
#endif
888890
#if defined(__aarch64__) && defined(__linux__)
@@ -941,6 +943,7 @@ struct MLAS_NCHWC_CONV_POINTWISE_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM
941943

942944
MLAS_CONV_POINTWISE_FLOAT_KERNEL* KernelToUse = Kernel;
943945
#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC) && !defined(_WIN32)
946+
// Heuristically select the AArch64 assembly kernel for larger convolutions
944947
if (!WorkBlock->UseBf16 && OutputThisIteration >= 4 &&
945948
StrideHeight == 1 && StrideWidth == 1) {
946949
KernelToUse = KernelFast;

0 commit comments

Comments
 (0)