106106 const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0 ;
107107
108108 const size_t BlockSize = MlasNchwcGetBlockSize ();
109+ const float32x4_t ZeroVector = vdupq_n_f32(0.0f);
109110
110111 const size_t StrideWidthElements = StrideWidth / sizeof (float );
111112 const size_t DilationWidthElements = DilationWidth / sizeof (float );
@@ -125,20 +126,17 @@ void
125126 const float * filter = Filter + filterSetBlock * FilterStrideElements;
126127 float * output = Output + filterSetBlock * OutputStrideElements;
127128
128- float accumulator[BlockSize] ;
129+ float32x4_t Accumulator ;
129130
130- for (size_t i = 0 ; i < BlockSize; i++) {
131- accumulator[i] = 0.0f;
132- }
133131 if (AccumulateOutput) {
134- for ( size_t i = 0 ; i < BlockSize; i++) {
135- accumulator[i] = output[output_idx * BlockSize + i];
136- }
132+ Accumulator = MlasLoadFloat32x4 (&output[output_idx * BlockSize]);
133+ } else {
134+ Accumulator = vdupq_n_f32(0.0f);
137135 }
136+
138137 if (BiasAddition) {
139- for (size_t i = 0 ; i < BlockSize; i++) {
140- accumulator[i] += Bias[filterSetBlock * BlockSize + i];
141- }
138+ const float32x4_t BiasVector = MlasLoadFloat32x4 (&Bias[filterSetBlock * BlockSize]);
139+ Accumulator = vaddq_f32 (Accumulator, BiasVector);
142140 }
143141
144142 for (size_t kh = 0 ; kh < KernelHeight; kh++) {
@@ -147,37 +145,35 @@ void
147145 kh * DilatedInputWidthElements + kw * DilationWidthElements;
148146
149147 for (size_t filterBlock = 0 ; filterBlock < BlockSize; filterBlock++) {
150- for (size_t ic = 0 ; ic < BlockSize; ic++) {
151- size_t kernel_pos = kh * (KernelWidth * BlockSize * BlockSize) +
152- kw * (BlockSize * BlockSize) +
153- filterBlock * (BlockSize) +
154- ic;
155-
156- const float * input_element = input_base + filterBlock;
157- const float * input_row_start = InputBase + kh * DilatedInputWidthElements;
158- const float * input_row_end = input_row_start + InputWidthElements;
159-
160- float input_value;
161- if (is_main_region || (input_element >= input_row_start && input_element < input_row_end)) {
162- input_value = *input_element;
163- } else {
164- input_value = 0.0f;
165- }
166-
167- float filter_value = filter[kernel_pos];
168- accumulator[ic] += input_value * filter_value;
148+ const float * input_element = input_base + filterBlock;
149+ const float * input_row_start = InputBase + kh * DilatedInputWidthElements;
150+ const float * input_row_end = input_row_start + InputWidthElements;
151+
152+ float input_value;
153+ if (is_main_region || (input_element >= input_row_start && input_element < input_row_end)) {
154+ input_value = *input_element;
155+ } else {
156+ input_value = 0.0f;
169157 }
170- }
171158
159+ const float32x4_t InputVector = vdupq_n_f32 (input_value);
160+
161+ size_t kernel_base_pos = kh * (KernelWidth * BlockSize * BlockSize) +
162+ kw * (BlockSize * BlockSize) +
163+ filterBlock * BlockSize;
164+
165+ const float32x4_t FilterVector = MlasLoadFloat32x4 (&filter[kernel_base_pos]);
166+
167+ Accumulator = MlasMultiplyAddFloat32x4 (InputVector, FilterVector, Accumulator);
168+ }
172169 }
173170 }
174171
175- for (size_t i = 0 ; i < BlockSize; i++) {
176- if (ReluActivation && accumulator[i] < 0.0f) {
177- accumulator[i] = 0.0f;
178- }
179- output[output_idx * BlockSize + i] = accumulator[i];
172+ if (ReluActivation) {
173+ Accumulator = MlasMaximumFloat32x4 (Accumulator, ZeroVector);
180174 }
175+
176+ MlasStoreFloat32x4 (&output[output_idx * BlockSize], Accumulator);
181177 }
182178 }
183179}
0 commit comments