Skip to content

Commit 333ad49

Browse files
hariharans29, github-actions[bot], and Copilot
authored and committed
[MLAS/NEON] Add dedicated kernel for depthwise convolution for ARM64 using NEON intrinsics (microsoft#26688)
### Description **Motivation and approach taken:** Add a dedicated depthwise convolution kernel for the most common depthwise convolution configuration (3x3 filter, stride = 1, pad <= 1, dilation = 1) using NEON intrinsics. This does significantly better than the current approach of `Im2Col + SGemm`. The Im2Col step extracts convolution patches and this is a wasteful step and for a 3x3 filter, K would be 9 for the SGemm and usually Gemms are not optimized for such small `K` values. Hence, a dedicated kernel works much better. Initially, I ported over the Winograd based NEON accelerated depthwise convolution kernel from PyTorch but I found that its performance is not very good. Its poor performance is probably due to applying the Winograd transformation for the filter repeatedly. A better approach may be to transform the filter offline and this approach can be considered for later (I reverted the PyTorch Winograd implementation in this commit: microsoft@2820a84). The current depthwise kernel added in this PR was authored by GPT5.1-Codex and with some minor bug fixes it seems to be functionally correct now and also provides the perf boost we are seeking. **Unit tests:** There are already depthwise convolution tests existing in the codebase. I don't see a need for new ones at this point. **Kernel benchmarking:** This is the kernel-level perf improvement from MLAS Conv benchmarks (about 50% kernel latency improvement): <img width="1055" height="90" alt="image" src="https://github.com/user-attachments/assets/ead9eb83-2d62-4157-a065-70c67c8c7517" /> ### Motivation and Context A key customer model had a few depthwise convolution operations and this change provides a **non-negligible ~3% throughput improvement** using the customer-provided benchmarking setup. For those interested, microsoft#26654 adds support for the same type of convolution variant but that leverages SME1/SME2 through KleidiAI. This PR is conceptually the same but targeting NEON only platforms. 
--------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent e74cd86 commit 333ad49

File tree

9 files changed

+447
-18
lines changed

9 files changed

+447
-18
lines changed

cmake/onnxruntime_mlas.cmake

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ function(setup_mlas_source_for_windows)
115115
${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
116116
${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp
117117
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
118+
${MLAS_SRC_DIR}/sconv_nchw_kernel_neon.cpp
118119
)
119120

120121
set(mlas_platform_preprocess_srcs
@@ -310,9 +311,9 @@ endfunction()
310311

311312
function (setup_arm_neon_nchwc)
312313
target_sources(onnxruntime_mlas PRIVATE
313-
${MLAS_SRC_DIR}/sconv.h
314-
${MLAS_SRC_DIR}/sconv_kernel_neon.cpp
315-
${MLAS_SRC_DIR}/spool_kernel_neon.cpp
314+
${MLAS_SRC_DIR}/sconv_nchwc_kernel_neon.h
315+
${MLAS_SRC_DIR}/sconv_nchwc_kernel_neon.cpp
316+
${MLAS_SRC_DIR}/spool_nchwc_kernel_neon.cpp
316317
)
317318
list(APPEND mlas_private_compile_definitions MLAS_USE_ARM_NEON_NCHWC)
318319
set(mlas_private_compile_definitions ${mlas_private_compile_definitions} PARENT_SCOPE)
@@ -466,6 +467,7 @@ else()
466467
${MLAS_SRC_DIR}/eltwise_kernel_neon.h
467468
${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
468469
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
470+
${MLAS_SRC_DIR}/sconv_nchw_kernel_neon.cpp
469471
)
470472

471473
# Conditionally add the SVE implementation if compiler supports it

onnxruntime/core/mlas/inc/mlas.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -854,7 +854,7 @@ enum MLAS_CONV_ALGORITHM {
854854
MlasConvAlgorithmGemmDirect,
855855
MlasConvAlgorithmExpandThenGemm,
856856
MlasConvAlgorithmExpandThenGemmSegmented,
857-
#if defined(MLAS_TARGET_WASM_SCALAR)
857+
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
858858
MlasConvAlgorithmDepthwise,
859859
#endif
860860
};

onnxruntime/core/mlas/lib/convolve.cpp

Lines changed: 135 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -698,13 +698,13 @@ Return Value:
698698
const size_t OutputGroupSize = FilterCount * OutputSize;
699699
const size_t FilterGroupSize = FilterCount * K;
700700

701+
const float* input = WorkBlock->Input + BatchGroupStart * InputGroupSize;
702+
float* output = WorkBlock->Output + BatchGroupStart * OutputGroupSize;
703+
701704
for (size_t bg = BatchGroupStart; bg < BatchGroupEnd; bg++) {
702705

703706
size_t group = bg % GroupCount;
704-
705-
const float* input = WorkBlock->Input + bg * InputGroupSize;
706707
const float* filter = WorkBlock->Filter + group * FilterGroupSize;
707-
float* output = WorkBlock->Output + bg * OutputGroupSize;
708708

709709
//
710710
// Invoke the non-threaded GEMM directly with the input tensor.
@@ -726,6 +726,9 @@ Return Value:
726726

727727
MlasActivation(Parameters->Activation, output, bias, FilterCount,
728728
OutputSize, OutputSize);
729+
730+
input += InputGroupSize;
731+
output += OutputGroupSize;
729732
}
730733
}
731734

@@ -805,6 +808,90 @@ Return Value:
805808
}
806809
}
807810

811+
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
812+
813+
void
814+
MlasDepthwiseThreaded(
815+
void* Context,
816+
ptrdiff_t Index
817+
)
818+
819+
/*++
820+
821+
Routine Description:
822+
823+
This routine is invoked from a worker thread to execute a segment of a
824+
convolution operation.
825+
826+
If using this, the entire convolution operation is parallelized on the
827+
(batch size * group count) parameter and this routine has logic to
828+
perform a specific thread's shard of the entire Convolution operation.
829+
830+
Arguments:
831+
832+
Context - Supplies the pointer to the context for the threaded operation.
833+
834+
Index - Supplies the current index of the threaded operation.
835+
836+
Return Value:
837+
838+
None.
839+
840+
--*/
841+
842+
{
843+
844+
MLAS_CONV_WORK_BLOCK* WorkBlock = (MLAS_CONV_WORK_BLOCK*)Context;
845+
846+
const MLAS_CONV_PARAMETERS* Parameters = WorkBlock->Parameters;
847+
848+
const size_t GroupCount = Parameters->GroupCount;
849+
const size_t BatchGroupCount = Parameters->BatchCount * GroupCount;
850+
851+
const size_t TargetThreadCount = WorkBlock->TargetThreadCount;
852+
853+
const size_t BatchGroupCountPerThread = BatchGroupCount / TargetThreadCount;
854+
const size_t BatchGroupCountExtra = BatchGroupCount % TargetThreadCount;
855+
856+
size_t BatchGroupStart;
857+
size_t BatchGroupEnd;
858+
859+
if (static_cast<size_t>(Index) < BatchGroupCountExtra) {
860+
BatchGroupStart = (BatchGroupCountPerThread + 1) * Index;
861+
BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread + 1;
862+
} else {
863+
BatchGroupStart = BatchGroupCountPerThread * Index + BatchGroupCountExtra;
864+
BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread;
865+
}
866+
867+
const size_t FilterCount = Parameters->FilterCount;
868+
const size_t OutputSize = Parameters->OutputSize;
869+
const size_t K = Parameters->K;
870+
871+
const size_t InputGroupSize = Parameters->InputChannels * Parameters->InputSize;
872+
const size_t OutputGroupSize = FilterCount * OutputSize;
873+
const size_t FilterGroupSize = FilterCount * K;
874+
875+
for (size_t bg = BatchGroupStart; bg < BatchGroupEnd; bg++) {
876+
size_t group = bg % GroupCount;
877+
878+
const float* input = WorkBlock->Input + bg * InputGroupSize;
879+
const float* filter = WorkBlock->Filter + group * FilterGroupSize;
880+
float* output = WorkBlock->Output + bg * OutputGroupSize;
881+
const float* bias = WorkBlock->Bias;
882+
if (bias != nullptr) {
883+
bias += group * FilterCount;
884+
}
885+
886+
float* WorkingBuffer = WorkBlock->WorkingBuffer;
887+
888+
MlasConvDepthwiseFloat_CHW(Parameters, input, filter, output, WorkingBuffer);
889+
MlasActivation(Parameters->Activation, output, bias, FilterCount, OutputSize, OutputSize);
890+
}
891+
}
892+
893+
#endif
894+
808895
inline
809896
bool
810897
MlasConvTryMultithread(
@@ -985,7 +1072,7 @@ Return Value:
9851072
return;
9861073
}
9871074

988-
#if defined(MLAS_TARGET_WASM_SCALAR)
1075+
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
9891076

9901077
if (Algorithm == MlasConvAlgorithmDepthwise) {
9911078
// Fill the Working Buffer with Zero for use by the depthwise kernel.
@@ -1019,6 +1106,35 @@ Return Value:
10191106
return;
10201107
}
10211108

1109+
1110+
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
1111+
1112+
if (Algorithm == MlasConvAlgorithmDepthwise && ((BatchCount > 1) || (GroupCount > 1))) {
1113+
const size_t BatchGroupCount = BatchCount * GroupCount;
1114+
1115+
ptrdiff_t TargetThreadCount = MlasGetMaximumThreadCount(ThreadPool);
1116+
1117+
if (static_cast<size_t>(TargetThreadCount) >= BatchGroupCount) {
1118+
TargetThreadCount = static_cast<ptrdiff_t>(BatchGroupCount);
1119+
}
1120+
1121+
MLAS_CONV_WORK_BLOCK WorkBlock;
1122+
1123+
WorkBlock.Parameters = Parameters;
1124+
WorkBlock.Input = Input;
1125+
WorkBlock.Filter = Filter;
1126+
WorkBlock.Bias = Bias;
1127+
WorkBlock.WorkingBuffer = WorkingBuffer;
1128+
WorkBlock.Output = Output;
1129+
WorkBlock.TargetThreadCount = TargetThreadCount;
1130+
1131+
MlasExecuteThreaded(MlasDepthwiseThreaded, &WorkBlock, TargetThreadCount, ThreadPool);
1132+
1133+
return;
1134+
}
1135+
1136+
#endif
1137+
10221138
//
10231139
// Iterate over each batch and group.
10241140
//
@@ -1082,7 +1198,7 @@ Return Value:
10821198
break;
10831199
}
10841200

1085-
#if defined(MLAS_TARGET_WASM_SCALAR)
1201+
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
10861202

10871203
case MlasConvAlgorithmDepthwise:
10881204
{
@@ -1337,17 +1453,26 @@ Return Value:
13371453

13381454
} else {
13391455

1340-
#if defined(MLAS_TARGET_WASM_SCALAR)
1456+
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
13411457

1342-
// Scalar direct conv for depthwise convolution.
1343-
// Currently only support 3x3 kernel with padding <=1 and dilations = 1.
1458+
// Scalar (WASM_SCALAR) / vectorized (ARM64) direct conv for depthwise convolution.
1459+
// Currently only support 3x3 kernel with padding <=1 and dilations = 1
1460+
// and on ARM64, it is further restricted to strides = 1.
13441461
// TODO: support more general depthwise convolution.
13451462

1463+
// On ARM64, only support stride = 1 for depthwise conv.
1464+
#if defined(MLAS_TARGET_ARM64)
1465+
bool depthwise_conv_stride_support_check = Parameters->StrideShape[0] == 1 && Parameters->StrideShape[1] == 1;
1466+
#else
1467+
bool depthwise_conv_stride_support_check = true;
1468+
#endif
1469+
13461470
if (Dimensions == 2
13471471
&& Parameters->FilterCount == 1 && Parameters->InputChannels == 1
13481472
&& Parameters->KernelShape[0] == 3 && Parameters->KernelShape[1] == 3
13491473
&& Parameters->Padding[0] <= 1 && Parameters->Padding[1] <= 1
13501474
&& Parameters->Padding[2] <= 1 && Parameters->Padding[3] <= 1
1475+
&& depthwise_conv_stride_support_check
13511476
&& Parameters->DilationShape[0] == 1 && Parameters->DilationShape[1] == 1) {
13521477

13531478
*WorkingBufferSize = Parameters->InputShape[1] + 2;
@@ -1411,8 +1536,8 @@ Return Value:
14111536

14121537
if (Parameters->BatchCount > 1 || Parameters->GroupCount > 1) {
14131538

1414-
size_t WorkingBufferSizePerThread = std::max({Parameters->OutputSize * Parameters->K,
1415-
Parameters->FilterCount * Parameters->OutputSize,
1539+
size_t WorkingBufferSizePerThread = std::max({Parameters->OutputSize * Parameters->K,
1540+
Parameters->FilterCount * Parameters->OutputSize,
14161541
static_cast<size_t>(MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD)});
14171542
TargetThreadCount = MaximumThreadCount;
14181543
if (static_cast<size_t>(TargetThreadCount) >= Parameters->BatchCount * Parameters->GroupCount) {

onnxruntime/core/mlas/lib/mlasi.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1645,7 +1645,8 @@ MlasFp32FromBits(
16451645
#pragma warning(pop)
16461646
#endif
16471647

1648-
#if defined(MLAS_TARGET_WASM_SCALAR)
1648+
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
1649+
16491650

16501651
void
16511652
MLASCALL

0 commit comments

Comments
 (0)