Skip to content

Commit b92dd45

Browse files
Refactor to share some code
1 parent 382e61b commit b92dd45

File tree

1 file changed

+109
-87
lines changed

1 file changed

+109
-87
lines changed

onnxruntime/core/mlas/lib/sconv_kernel_neon.cpp

Lines changed: 109 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@ Module Name:
2323

2424
#include "mlasi.h"
2525

26+
// Common implementation for NCHW and NCHWC convolution kernels
27+
template <bool IsNchwcFormat>
2628
void
2729
MLASCALL
28-
MlasConvNchwFloatKernelNeon(
30+
MlasConvFloatKernelNeonImpl(
2931
const float* Input,
3032
const float* Filter,
3133
float* Output,
@@ -90,23 +92,50 @@ void
9092
const float* input_base = Input + output_idx * StrideWidthElements +
9193
kh * DilatedInputWidthElements + kw * DilationWidthElements;
9294

93-
const float* input_row_start = InputBase + kh * DilatedInputWidthElements;
94-
const float* input_row_end = input_row_start + InputWidthElements;
95+
if (IsNchwcFormat) {
96+
// NCHWC format - process each element in the block
97+
for (size_t filterBlock = 0; filterBlock < BlockSize; filterBlock++) {
98+
const float* input_element = input_base + filterBlock;
99+
const float* input_row_start = InputBase + kh * DilatedInputWidthElements;
100+
const float* input_row_end = input_row_start + InputWidthElements;
101+
102+
float input_value;
103+
if (is_main_region || (input_element >= input_row_start && input_element < input_row_end)) {
104+
input_value = *input_element;
105+
} else {
106+
input_value = 0.0f;
107+
}
108+
109+
const float32x4_t InputVector = MlasBroadcastFloat32x4(input_value);
110+
111+
size_t kernel_base_pos = kh * (KernelWidth * BlockSize * BlockSize) +
112+
kw * (BlockSize * BlockSize) +
113+
filterBlock * BlockSize;
95114

96-
float input_value;
97-
if (is_main_region || (input_base >= input_row_start && input_base < input_row_end)) {
98-
input_value = *input_base;
115+
const float32x4_t FilterVector = MlasLoadFloat32x4(&filter[kernel_base_pos]);
116+
117+
Accumulator = MlasMultiplyAddFloat32x4(InputVector, FilterVector, Accumulator);
118+
}
99119
} else {
100-
input_value = 0.0f;
101-
}
120+
// NCHW format - simpler processing
121+
const float* input_row_start = InputBase + kh * DilatedInputWidthElements;
122+
const float* input_row_end = input_row_start + InputWidthElements;
102123

103-
const float32x4_t InputVector = MlasBroadcastFloat32x4(input_value);
124+
float input_value;
125+
if (is_main_region || (input_base >= input_row_start && input_base < input_row_end)) {
126+
input_value = *input_base;
127+
} else {
128+
input_value = 0.0f;
129+
}
104130

105-
size_t kernel_base_pos = kh * KernelWidth + kw;
131+
const float32x4_t InputVector = MlasBroadcastFloat32x4(input_value);
106132

107-
const float32x4_t FilterVector = MlasLoadFloat32x4(&filter[kernel_base_pos * BlockSize]);
133+
size_t kernel_base_pos = kh * KernelWidth + kw;
108134

109-
Accumulator = MlasMultiplyAddFloat32x4(InputVector, FilterVector, Accumulator);
135+
const float32x4_t FilterVector = MlasLoadFloat32x4(&filter[kernel_base_pos * BlockSize]);
136+
137+
Accumulator = MlasMultiplyAddFloat32x4(InputVector, FilterVector, Accumulator);
138+
}
110139
}
111140
}
112141

@@ -119,6 +148,53 @@ void
119148
}
120149
}
121150

151+
void
152+
MLASCALL
153+
MlasConvNchwFloatKernelNeon(
154+
const float* Input,
155+
const float* Filter,
156+
float* Output,
157+
size_t StrideWidth,
158+
size_t DilationWidth,
159+
size_t FilterCount,
160+
size_t InputStride,
161+
size_t FilterStride,
162+
size_t OutputStride,
163+
size_t KernelHeight,
164+
size_t KernelWidth,
165+
const float* InputBase,
166+
size_t InputWidth,
167+
size_t DilatedInputWidth,
168+
size_t OutputCountLeftPad,
169+
size_t OutputCount,
170+
size_t OutputCountRightPad,
171+
const float* Bias,
172+
unsigned KernelFlags
173+
)
174+
{
175+
MlasConvFloatKernelNeonImpl<false>(
176+
Input,
177+
Filter,
178+
Output,
179+
StrideWidth,
180+
DilationWidth,
181+
FilterCount,
182+
InputStride,
183+
FilterStride,
184+
OutputStride,
185+
KernelHeight,
186+
KernelWidth,
187+
InputBase,
188+
InputWidth,
189+
DilatedInputWidth,
190+
OutputCountLeftPad,
191+
OutputCount,
192+
OutputCountRightPad,
193+
Bias,
194+
KernelFlags
195+
);
196+
}
197+
122198
//
123199
// Implementation of MlasConvNchwcFloatKernelNeon
124200
//
@@ -147,81 +223,27 @@ void
147223
unsigned KernelFlags
148224
)
149225
{
150-
const bool AccumulateOutput = (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0;
151-
const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0;
152-
const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0;
153-
154-
const size_t BlockSize = MlasNchwcGetBlockSize();
155-
const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f);
156-
157-
const size_t StrideWidthElements = StrideWidth / sizeof(float);
158-
const size_t DilationWidthElements = DilationWidth / sizeof(float);
159-
const size_t FilterStrideElements = FilterStride / sizeof(float);
160-
const size_t OutputStrideElements = OutputStride / sizeof(float);
161-
const size_t InputWidthElements = InputWidth / sizeof(float);
162-
const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float);
163-
164-
(void)InputStride;
165-
166-
const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad;
167-
168-
for (size_t output_idx = 0; output_idx < TotalOutputCount; output_idx++) {
169-
bool is_main_region = (output_idx >= OutputCountLeftPad && output_idx < OutputCountLeftPad + OutputCount);
170-
171-
for (size_t filterSetBlock = 0; filterSetBlock < FilterCount; filterSetBlock++) {
172-
const float* filter = Filter + filterSetBlock * FilterStrideElements;
173-
float* output = Output + filterSetBlock * OutputStrideElements;
174-
175-
float32x4_t Accumulator;
176-
177-
if (AccumulateOutput) {
178-
Accumulator = MlasLoadFloat32x4(&output[output_idx * BlockSize]);
179-
} else {
180-
Accumulator = MlasBroadcastFloat32x4(0.0f);
181-
}
182-
183-
if (BiasAddition) {
184-
const float32x4_t BiasVector = MlasLoadFloat32x4(&Bias[filterSetBlock * BlockSize]);
185-
Accumulator = MlasAddFloat32x4(Accumulator, BiasVector);
186-
}
187-
188-
for (size_t kh = 0; kh < KernelHeight; kh++) {
189-
for (size_t kw = 0; kw < KernelWidth; kw++) {
190-
const float* input_base = Input + output_idx * StrideWidthElements +
191-
kh * DilatedInputWidthElements + kw * DilationWidthElements;
192-
193-
for (size_t filterBlock = 0; filterBlock < BlockSize; filterBlock++) {
194-
const float* input_element = input_base + filterBlock;
195-
const float* input_row_start = InputBase + kh * DilatedInputWidthElements;
196-
const float* input_row_end = input_row_start + InputWidthElements;
197-
198-
float input_value;
199-
if (is_main_region || (input_element >= input_row_start && input_element < input_row_end)) {
200-
input_value = *input_element;
201-
} else {
202-
input_value = 0.0f;
203-
}
204-
205-
const float32x4_t InputVector = MlasBroadcastFloat32x4(input_value);
206-
207-
size_t kernel_base_pos = kh * (KernelWidth * BlockSize * BlockSize) +
208-
kw * (BlockSize * BlockSize) +
209-
filterBlock * BlockSize;
210-
211-
const float32x4_t FilterVector = MlasLoadFloat32x4(&filter[kernel_base_pos]);
212-
213-
Accumulator = MlasMultiplyAddFloat32x4(InputVector, FilterVector, Accumulator);
214-
}
215-
}
216-
}
217-
218-
if (ReluActivation) {
219-
Accumulator = MlasMaximumFloat32x4(Accumulator, ZeroVector);
220-
}
221-
222-
MlasStoreFloat32x4(&output[output_idx * BlockSize], Accumulator);
223-
}
224-
}
226+
MlasConvFloatKernelNeonImpl<true>(
227+
Input,
228+
Filter,
229+
Output,
230+
StrideWidth,
231+
DilationWidth,
232+
FilterCount,
233+
InputStride,
234+
FilterStride,
235+
OutputStride,
236+
KernelHeight,
237+
KernelWidth,
238+
InputBase,
239+
InputWidth,
240+
DilatedInputWidth,
241+
OutputCountLeftPad,
242+
OutputCount,
243+
OutputCountRightPad,
244+
Bias,
245+
KernelFlags
246+
);
225247
}
226248

227249
//

0 commit comments

Comments
 (0)