@@ -21,7 +21,6 @@ Module Name:
 #include <algorithm>
 #include <cstddef>
 
-#include "arm_neon.h"
 #include "mlasi.h"
 
 void
@@ -53,7 +52,7 @@
     const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0;
 
     const size_t BlockSize = MlasNchwcGetBlockSize();
-    const float32x4_t ZeroVector = vdupq_n_f32(0.0f);
+    const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f);
 
     const size_t StrideWidthElements = StrideWidth / sizeof(float);
     const size_t DilationWidthElements = DilationWidth / sizeof(float);
@@ -78,12 +77,12 @@
             if (AccumulateOutput) {
                 Accumulator = MlasLoadFloat32x4(&output[output_idx * BlockSize]);
             } else {
-                Accumulator = vdupq_n_f32(0.0f);
+                Accumulator = MlasBroadcastFloat32x4(0.0f);
             }
 
             if (BiasAddition) {
                 const float32x4_t BiasVector = MlasLoadFloat32x4(&Bias[filterSetBlock * BlockSize]);
-                Accumulator = vaddq_f32(Accumulator, BiasVector);
+                Accumulator = MlasAddFloat32x4(Accumulator, BiasVector);
             }
 
             for (size_t kh = 0; kh < KernelHeight; kh++) {
@@ -101,7 +100,7 @@
                         input_value = 0.0f;
                     }
 
-                    const float32x4_t InputVector = vdupq_n_f32(input_value);
+                    const float32x4_t InputVector = MlasBroadcastFloat32x4(input_value);
 
                     size_t kernel_base_pos = kh * KernelWidth + kw;
 
@@ -153,7 +152,7 @@
     const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0;
 
     const size_t BlockSize = MlasNchwcGetBlockSize();
-    const float32x4_t ZeroVector = vdupq_n_f32(0.0f);
+    const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f);
 
     const size_t StrideWidthElements = StrideWidth / sizeof(float);
     const size_t DilationWidthElements = DilationWidth / sizeof(float);
@@ -178,12 +177,12 @@ void
             if (AccumulateOutput) {
                 Accumulator = MlasLoadFloat32x4(&output[output_idx * BlockSize]);
             } else {
-                Accumulator = vdupq_n_f32(0.0f);
+                Accumulator = MlasBroadcastFloat32x4(0.0f);
             }
 
             if (BiasAddition) {
                 const float32x4_t BiasVector = MlasLoadFloat32x4(&Bias[filterSetBlock * BlockSize]);
-                Accumulator = vaddq_f32(Accumulator, BiasVector);
+                Accumulator = MlasAddFloat32x4(Accumulator, BiasVector);
             }
 
             for (size_t kh = 0; kh < KernelHeight; kh++) {
@@ -203,7 +202,7 @@
                         input_value = 0.0f;
                     }
 
-                    const float32x4_t InputVector = vdupq_n_f32(input_value);
+                    const float32x4_t InputVector = MlasBroadcastFloat32x4(input_value);
 
                     size_t kernel_base_pos = kh * (KernelWidth * BlockSize * BlockSize) +
                                              kw * (BlockSize * BlockSize) +
@@ -259,7 +258,7 @@
     const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0;
 
     const size_t BlockSize = MlasNchwcGetBlockSize();
-    const float32x4_t ZeroVector = vdupq_n_f32(0.0f);
+    const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f);
 
     const size_t StrideWidthElements = StrideWidth / sizeof(float);
     const size_t DilationWidthElements = DilationWidth / sizeof(float);
@@ -279,10 +278,11 @@ void
 
         if (AccumulateOutput) {
             Accumulator = MlasLoadFloat32x4(&Output[output_idx * BlockSize]);
-        } else if (BiasAddition) {
-            Accumulator = MlasLoadFloat32x4(Bias);
         } else {
-            Accumulator = vdupq_n_f32(0.0f);
+            Accumulator = MlasBroadcastFloat32x4(0.0f);
+        }
+        if (BiasAddition) {
+            Accumulator = MlasAddFloat32x4(Accumulator, MlasLoadFloat32x4(Bias));
         }
 
         for (size_t kh = 0; kh < KernelHeight; kh++) {
@@ -361,25 +361,27 @@ void
     const size_t OutputStrideElements = OutputStride / sizeof(float);
 
     const size_t BlockSize = MlasNchwcGetBlockSize();
-    const float32x4_t ZeroVector = vdupq_n_f32(0.0f);
+    const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f);
 
-    for (size_t i = 0; i < OutputCount; i++) {
+    for (size_t output_idx = 0; output_idx < OutputCount; output_idx++) {
         for (size_t f = 0; f < FilterCount; f++) {
             const float* filter = Filter + f * FilterStrideElements;
             float* output = Output + f * OutputStrideElements;
             float32x4_t Accumulator;
             if (AccumulateOutput) {
-                Accumulator = MlasLoadFloat32x4(&output[i * BlockSize]);
-            } else if (BiasAddition) {
-                Accumulator = MlasLoadFloat32x4(&Bias[f * BlockSize]);
+                Accumulator = MlasLoadFloat32x4(&output[output_idx * BlockSize]);
             } else {
-                Accumulator = vdupq_n_f32(0.0f);
+                Accumulator = MlasBroadcastFloat32x4(0.0f);
+            }
+            if (BiasAddition) {
+                const float32x4_t BiasVector = MlasLoadFloat32x4(&Bias[f * BlockSize]);
+                Accumulator = MlasAddFloat32x4(Accumulator, BiasVector);
             }
             for (size_t c = 0; c < InputChannels; c++) {
-                const float* input_ptr = Input + c * InputStrideElements + i * StrideWidthElements;
+                const float* input_ptr = Input + c * InputStrideElements + output_idx * StrideWidthElements;
                 for (size_t input_b = 0; input_b < BlockSize; input_b++) {
                     const float input_value = input_ptr[input_b];
-                    const float32x4_t InputVector = vdupq_n_f32(input_value);
+                    const float32x4_t InputVector = MlasBroadcastFloat32x4(input_value);
                     const float* filter_ptr = filter + (c * BlockSize + input_b) * BlockSize;
                     const float32x4_t FilterVector = MlasLoadFloat32x4(filter_ptr);
                     Accumulator = MlasMultiplyAddFloat32x4(InputVector, FilterVector, Accumulator);
@@ -388,7 +390,7 @@
                 if (ReluActivation) {
                     Accumulator = MlasMaximumFloat32x4(Accumulator, ZeroVector);
                 }
-                MlasStoreFloat32x4(&output[i * BlockSize], Accumulator);
+                MlasStoreFloat32x4(&output[output_idx * BlockSize], Accumulator);
             }
         }
     }
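
The pattern across all four kernels is the same: direct NEON intrinsics (`vdupq_n_f32`, `vaddq_f32`) are swapped for the portable MLAS vector wrappers, so the `arm_neon.h` include can be dropped and the kernels go through the same abstraction already used by the surrounding loads, stores, and fused multiply-adds. A minimal sketch of what those two wrappers expand to on a NEON target, paraphrased from `mlasi.h` (simplified here; the real header dispatches per architecture and marks these force-inline):

```cpp
#include <arm_neon.h>

typedef float32x4_t MLAS_FLOAT32X4;

// Splat one scalar across all four lanes; on NEON this is exactly vdupq_n_f32.
MLAS_FLOAT32X4
MlasBroadcastFloat32x4(float Value) { return vdupq_n_f32(Value); }

// Lane-wise addition; on NEON this is exactly vaddq_f32.
MLAS_FLOAT32X4
MlasAddFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) { return vaddq_f32(Vector1, Vector2); }
```

The restructured bias handling in the last two kernels is also a behavior fix, not just a cleanup: the old `else if (BiasAddition)` chain skipped the bias whenever `AccumulateOutput` seeded the accumulator from the existing output, whereas the new code initializes the accumulator first (from the output or from zero) and then adds the bias unconditionally when `BiasAddition` is set.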