@@ -23,9 +23,11 @@ Module Name:
 
 #include "mlasi.h"
 
+// Common implementation for NCHW and NCHWC convolution kernels
+template <bool IsNchwcFormat>
 void
 MLASCALL
-MlasConvNchwFloatKernelNeon(
+MlasConvFloatKernelNeonImpl(
     const float* Input,
     const float* Filter,
     float* Output,
@@ -90,23 +92,50 @@
                     const float* input_base = Input + output_idx * StrideWidthElements +
                         kh * DilatedInputWidthElements + kw * DilationWidthElements;
 
-                    const float* input_row_start = InputBase + kh * DilatedInputWidthElements;
-                    const float* input_row_end = input_row_start + InputWidthElements;
+                    if (IsNchwcFormat) {
+                        // NCHWC format - process each element in the block
+                        for (size_t filterBlock = 0; filterBlock < BlockSize; filterBlock++) {
+                            const float* input_element = input_base + filterBlock;
+                            const float* input_row_start = InputBase + kh * DilatedInputWidthElements;
+                            const float* input_row_end = input_row_start + InputWidthElements;
+
+                            float input_value;
+                            if (is_main_region || (input_element >= input_row_start && input_element < input_row_end)) {
+                                input_value = *input_element;
+                            } else {
+                                input_value = 0.0f;
+                            }
+
+                            const float32x4_t InputVector = MlasBroadcastFloat32x4(input_value);
+
+                            size_t kernel_base_pos = kh * (KernelWidth * BlockSize * BlockSize) +
+                                kw * (BlockSize * BlockSize) +
+                                filterBlock * BlockSize;
 
-                    float input_value;
-                    if (is_main_region || (input_base >= input_row_start && input_base < input_row_end)) {
-                        input_value = *input_base;
+                            const float32x4_t FilterVector = MlasLoadFloat32x4(&filter[kernel_base_pos]);
+
+                            Accumulator = MlasMultiplyAddFloat32x4(InputVector, FilterVector, Accumulator);
+                        }
                     } else {
-                        input_value = 0.0f;
-                    }
+                        // NCHW format - simpler processing
+                        const float* input_row_start = InputBase + kh * DilatedInputWidthElements;
+                        const float* input_row_end = input_row_start + InputWidthElements;
 
-                    const float32x4_t InputVector = MlasBroadcastFloat32x4(input_value);
+                        float input_value;
+                        if (is_main_region || (input_base >= input_row_start && input_base < input_row_end)) {
+                            input_value = *input_base;
+                        } else {
+                            input_value = 0.0f;
+                        }
 
-                    size_t kernel_base_pos = kh * KernelWidth + kw;
+                        const float32x4_t InputVector = MlasBroadcastFloat32x4(input_value);
 
-                    const float32x4_t FilterVector = MlasLoadFloat32x4(&filter[kernel_base_pos * BlockSize]);
+                        size_t kernel_base_pos = kh * KernelWidth + kw;
 
-                    Accumulator = MlasMultiplyAddFloat32x4(InputVector, FilterVector, Accumulator);
+                        const float32x4_t FilterVector = MlasLoadFloat32x4(&filter[kernel_base_pos * BlockSize]);
+
+                        Accumulator = MlasMultiplyAddFloat32x4(InputVector, FilterVector, Accumulator);
+                    }
                 }
             }
 
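
Note on the NCHWc filter indexing added above: kernel_base_pos = kh * (KernelWidth * BlockSize * BlockSize) + kw * (BlockSize * BlockSize) + filterBlock * BlockSize reads as an offset into a filter block laid out as [kernel row][kernel column][input lane within the block][output lane within the block], with MlasLoadFloat32x4 then loading four consecutive output-lane weights from that offset. The standalone sketch below (not part of the commit) simply replays that arithmetic; BlockSize = 8 and KernelWidth = 3 are assumed example values chosen for illustration, while the kernel itself takes BlockSize from MlasNchwcGetBlockSize().

#include <cassert>
#include <cstddef>

int main()
{
    const std::size_t BlockSize = 8;    // assumed example value, see note above
    const std::size_t KernelWidth = 3;  // assumed 3-wide kernel

    const std::size_t kh = 1, kw = 2, filterBlock = 5;

    // Same expression as in the kernel above.
    const std::size_t kernel_base_pos = kh * (KernelWidth * BlockSize * BlockSize) +
        kw * (BlockSize * BlockSize) +
        filterBlock * BlockSize;

    // 1 * 192 + 2 * 64 + 5 * 8 = 360: the first of the BlockSize weights
    // belonging to this (kh, kw, input lane) group.
    assert(kernel_base_pos == 360);
    return 0;
}
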
@@ -119,6 +148,53 @@ void
     }
 }
 
+void
+MLASCALL
+MlasConvNchwFloatKernelNeon(
+    const float* Input,
+    const float* Filter,
+    float* Output,
+    size_t StrideWidth,
+    size_t DilationWidth,
+    size_t FilterCount,
+    size_t InputStride,
+    size_t FilterStride,
+    size_t OutputStride,
+    size_t KernelHeight,
+    size_t KernelWidth,
+    const float* InputBase,
+    size_t InputWidth,
+    size_t DilatedInputWidth,
+    size_t OutputCountLeftPad,
+    size_t OutputCount,
+    size_t OutputCountRightPad,
+    const float* Bias,
+    unsigned KernelFlags
+    )
+{
+    MlasConvFloatKernelNeonImpl<false>(
+        Input,
+        Filter,
+        Output,
+        StrideWidth,
+        DilationWidth,
+        FilterCount,
+        InputStride,
+        FilterStride,
+        OutputStride,
+        KernelHeight,
+        KernelWidth,
+        InputBase,
+        InputWidth,
+        DilatedInputWidth,
+        OutputCountLeftPad,
+        OutputCount,
+        OutputCountRightPad,
+        Bias,
+        KernelFlags
+    );
+}
+
 //
 // Implementation of MlasConvNchwcFloatKernelNeon
 //
@@ -147,81 +223,27 @@ void
     unsigned KernelFlags
     )
 {
-    const bool AccumulateOutput = (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0;
-    const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0;
-    const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0;
-
-    const size_t BlockSize = MlasNchwcGetBlockSize();
-    const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f);
-
-    const size_t StrideWidthElements = StrideWidth / sizeof(float);
-    const size_t DilationWidthElements = DilationWidth / sizeof(float);
-    const size_t FilterStrideElements = FilterStride / sizeof(float);
-    const size_t OutputStrideElements = OutputStride / sizeof(float);
-    const size_t InputWidthElements = InputWidth / sizeof(float);
-    const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float);
-
-    (void)InputStride;
-
-    const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad;
-
-    for (size_t output_idx = 0; output_idx < TotalOutputCount; output_idx++) {
-        bool is_main_region = (output_idx >= OutputCountLeftPad && output_idx < OutputCountLeftPad + OutputCount);
-
-        for (size_t filterSetBlock = 0; filterSetBlock < FilterCount; filterSetBlock++) {
-            const float* filter = Filter + filterSetBlock * FilterStrideElements;
-            float* output = Output + filterSetBlock * OutputStrideElements;
-
-            float32x4_t Accumulator;
-
-            if (AccumulateOutput) {
-                Accumulator = MlasLoadFloat32x4(&output[output_idx * BlockSize]);
-            } else {
-                Accumulator = MlasBroadcastFloat32x4(0.0f);
-            }
-
-            if (BiasAddition) {
-                const float32x4_t BiasVector = MlasLoadFloat32x4(&Bias[filterSetBlock * BlockSize]);
-                Accumulator = MlasAddFloat32x4(Accumulator, BiasVector);
-            }
-
-            for (size_t kh = 0; kh < KernelHeight; kh++) {
-                for (size_t kw = 0; kw < KernelWidth; kw++) {
-                    const float* input_base = Input + output_idx * StrideWidthElements +
-                        kh * DilatedInputWidthElements + kw * DilationWidthElements;
-
-                    for (size_t filterBlock = 0; filterBlock < BlockSize; filterBlock++) {
-                        const float* input_element = input_base + filterBlock;
-                        const float* input_row_start = InputBase + kh * DilatedInputWidthElements;
-                        const float* input_row_end = input_row_start + InputWidthElements;
-
-                        float input_value;
-                        if (is_main_region || (input_element >= input_row_start && input_element < input_row_end)) {
-                            input_value = *input_element;
-                        } else {
-                            input_value = 0.0f;
-                        }
-
-                        const float32x4_t InputVector = MlasBroadcastFloat32x4(input_value);
-
-                        size_t kernel_base_pos = kh * (KernelWidth * BlockSize * BlockSize) +
-                            kw * (BlockSize * BlockSize) +
-                            filterBlock * BlockSize;
-
-                        const float32x4_t FilterVector = MlasLoadFloat32x4(&filter[kernel_base_pos]);
-
-                        Accumulator = MlasMultiplyAddFloat32x4(InputVector, FilterVector, Accumulator);
-                    }
-                }
-            }
-
-            if (ReluActivation) {
-                Accumulator = MlasMaximumFloat32x4(Accumulator, ZeroVector);
-            }
-
-            MlasStoreFloat32x4(&output[output_idx * BlockSize], Accumulator);
-        }
-    }
+    MlasConvFloatKernelNeonImpl<true>(
+        Input,
+        Filter,
+        Output,
+        StrideWidth,
+        DilationWidth,
+        FilterCount,
+        InputStride,
+        FilterStride,
+        OutputStride,
+        KernelHeight,
+        KernelWidth,
+        InputBase,
+        InputWidth,
+        DilatedInputWidth,
+        OutputCountLeftPad,
+        OutputCount,
+        OutputCountRightPad,
+        Bias,
+        KernelFlags
+    );
 }
 
 //
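
Taken as a whole, the commit collapses the two hand-written loop nests into the single templated MlasConvFloatKernelNeonImpl and leaves MlasConvNchwFloatKernelNeon and MlasConvNchwcFloatKernelNeon as thin wrappers that instantiate it with false and true. Below is a minimal, self-contained sketch of that dispatch pattern, with hypothetical names (ConvKernelImpl, ConvKernelNchw, ConvKernelNchwc) standing in for the MLAS kernels and a trivial per-element operation standing in for the convolution:

#include <cstdio>
#include <cstddef>

// Shared body; the format is a compile-time template parameter, mirroring
// MlasConvFloatKernelNeonImpl<IsNchwcFormat>.
template <bool IsBlockedFormat>
static void ConvKernelImpl(const float* input, float* output, std::size_t n)
{
    for (std::size_t i = 0; i < n; i++) {
        if (IsBlockedFormat) {
            // stand-in for the blocked (NCHWc-style) path
            output[i] = input[i] * 2.0f;
        } else {
            // stand-in for the plain (NCHW-style) path
            output[i] = input[i];
        }
    }
}

// Thin wrappers, one per layout, like the two exported MLAS kernels.
void ConvKernelNchw(const float* input, float* output, std::size_t n)
{
    ConvKernelImpl<false>(input, output, n);
}

void ConvKernelNchwc(const float* input, float* output, std::size_t n)
{
    ConvKernelImpl<true>(input, output, n);
}

int main()
{
    const float in[3] = {1.0f, 2.0f, 3.0f};
    float out[3];
    ConvKernelNchwc(in, out, 3);
    std::printf("%g %g %g\n", out[0], out[1], out[2]);
    return 0;
}

Because IsBlockedFormat (like IsNchwcFormat) is known at compile time, the per-element format branch inside the shared body should be folded away by the compiler, so each wrapper should end up with code equivalent to a dedicated kernel.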