@@ -317,90 +317,65 @@ void
     unsigned KernelFlags
     )
 {
-    // Mark unused parameters
-    (void)InputBase;
-    (void)InputWidth;
-
     const bool AccumulateOutput = (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0;
     const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0;
     const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0;
 
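+    // Number of channels packed per NCHWc block.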
+    const size_t BlockSize = MlasNchwcGetBlockSize();
     const float32x4_t ZeroVector = vdupq_n_f32(0.0f);
 
-    // Convert byte strides to element strides
     const size_t StrideWidthElements = StrideWidth / sizeof(float);
     const size_t DilationWidthElements = DilationWidth / sizeof(float);
     const size_t InputStrideElements = InputStride / sizeof(float);
     const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float);
 
-    // Mark unused variables in Depthwise kernel
     (void)InputStrideElements;
 
-    // For depthwise convolution, process 4 channels at a time
-    const float bias = BiasAddition ? Bias[0] : 0.0f;
-    const float32x4_t BiasVector = vdupq_n_f32(bias);
+    const size_t InputWidthElements = InputWidth / sizeof(float);
 
-    const float* input_ptr = Input;
-    size_t OutputIndex = 0;
+    const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad;
 
-    // Handle left padding
-    for (size_t i = 0; i < OutputCountLeftPad; i++) {
-        float32x4_t Accumulator = AccumulateOutput ? MlasLoadFloat32x4(&Output[OutputIndex]) : BiasVector;
-
-        for (size_t kh = 0; kh < KernelHeight; kh++) {
-            for (size_t kw = 0; kw < KernelWidth; kw++) {
-                const float* input_element = input_ptr + kh * DilatedInputWidthElements + kw * DilationWidthElements;
-                const float filter_value = Filter[kh * KernelWidth + kw];
-                const float32x4_t FilterVector = vdupq_n_f32(filter_value);
-                const float32x4_t InputVector = MlasLoadFloat32x4(input_element);
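+    // Single pass over left padding, main region, and right padding.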
+    for (size_t output_idx = 0; output_idx < TotalOutputCount; output_idx++) {
+        bool is_main_region = (output_idx >= OutputCountLeftPad && output_idx < OutputCountLeftPad + OutputCount);
 
-                Accumulator = MlasMultiplyAddFloat32x4(InputVector, FilterVector, Accumulator);
-            }
-        }
+        float32x4_t Accumulator;
 
-        if (ReluActivation) {
-            Accumulator = MlasMaximumFloat32x4(Accumulator, ZeroVector);
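+        // Seed the accumulator: running output, bias, or zero.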
+        if (AccumulateOutput) {
+            Accumulator = MlasLoadFloat32x4(&Output[output_idx * BlockSize]);
+        } else if (BiasAddition) {
+            Accumulator = MlasLoadFloat32x4(Bias);
+        } else {
+            Accumulator = vdupq_n_f32(0.0f);
         }
 
-        MlasStoreFloat32x4(&Output[OutputIndex], Accumulator);
-        OutputIndex += 4;
-        input_ptr += StrideWidthElements;
-    }
-
-    // Handle main output region
-    for (size_t i = 0; i < OutputCount; i++) {
-        float32x4_t Accumulator = AccumulateOutput ? MlasLoadFloat32x4(&Output[OutputIndex]) : BiasVector;
-
         for (size_t kh = 0; kh < KernelHeight; kh++) {
             for (size_t kw = 0; kw < KernelWidth; kw++) {
-                const float* input_element = input_ptr + kh * DilatedInputWidthElements + kw * DilationWidthElements;
-                const float filter_value = Filter[kh * KernelWidth + kw];
-                const float32x4_t FilterVector = vdupq_n_f32(filter_value);
-                const float32x4_t InputVector = MlasLoadFloat32x4(input_element);
-
-                Accumulator = MlasMultiplyAddFloat32x4(InputVector, FilterVector, Accumulator);
-            }
-        }
-
-        if (ReluActivation) {
-            Accumulator = MlasMaximumFloat32x4(Accumulator, ZeroVector);
-        }
-
-        MlasStoreFloat32x4(&Output[OutputIndex], Accumulator);
-        OutputIndex += 4;
-        input_ptr += StrideWidthElements;
-    }
-
-    // Handle right padding
-    for (size_t i = 0; i < OutputCountRightPad; i++) {
-        float32x4_t Accumulator = AccumulateOutput ? MlasLoadFloat32x4(&Output[OutputIndex]) : BiasVector;
+                size_t kernel_pos = kh * KernelWidth + kw;
+
+                const float* input_base = Input + output_idx * StrideWidthElements +
+                    kh * DilatedInputWidthElements + kw * DilationWidthElements;
+
+                float32x4_t InputVector;
+
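+                // In the padding regions, gather lane-by-lane so out-of-row taps read as zero.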
+                if (is_main_region) {
+                    InputVector = MlasLoadFloat32x4(input_base);
+                } else {
+                    float input_values[4];
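+                    // Assumes BlockSize == 4, matching the float32x4 lane count.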
+                    for (size_t i = 0; i < BlockSize; i++) {
+                        const float* input_element = input_base + i;
+                        const float* input_row_start = InputBase + kh * DilatedInputWidthElements;
+                        const float* input_row_end = input_row_start + InputWidthElements;
+
+                        if (input_element >= input_row_start && input_element < input_row_end) {
+                            input_values[i] = *input_element;
+                        } else {
+                            input_values[i] = 0.0f;
+                        }
+                    }
+                    InputVector = MlasLoadFloat32x4(input_values);
+                }
 
-        for (size_t kh = 0; kh < KernelHeight; kh++) {
-            for (size_t kw = 0; kw < KernelWidth; kw++) {
-                const float* input_element = input_ptr + kh * DilatedInputWidthElements + kw * DilationWidthElements;
-                const float filter_value = Filter[kh * KernelWidth + kw];
-                const float32x4_t FilterVector = vdupq_n_f32(filter_value);
-                const float32x4_t InputVector = MlasLoadFloat32x4(input_element);
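+                // NCHWc filter layout: BlockSize consecutive values per kernel tap.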
+                const float32x4_t FilterVector = MlasLoadFloat32x4(&Filter[kernel_pos * BlockSize]);
 
                 Accumulator = MlasMultiplyAddFloat32x4(InputVector, FilterVector, Accumulator);
             }
@@ -410,9 +385,7 @@
             Accumulator = MlasMaximumFloat32x4(Accumulator, ZeroVector);
         }
 
-        MlasStoreFloat32x4(&Output[OutputIndex], Accumulator);
-        OutputIndex += 4;
-        input_ptr += StrideWidthElements;
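+        // One BlockSize-wide vector store per output position.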
+        MlasStoreFloat32x4(&Output[output_idx * BlockSize], Accumulator);
     }
 }
 
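A scalar model can make the padding logic above easier to review. The sketch below mirrors the unified loop lane by lane under stated assumptions: strides are passed in float elements (the kernel itself receives byte strides), the block size is fixed at 4 to match float32x4, and the accumulate-output path is omitted. The signature and the name DepthwiseReference are hypothetical, not part of MLAS.

#include <algorithm>
#include <cstddef>

// Hypothetical scalar reference for the rewritten kernel above.
void DepthwiseReference(
    const float* Input,            // first tap for output 0 (may precede the valid row)
    const float* InputBase,        // start of the valid input data
    size_t InputWidthElements,     // valid elements per input row
    const float* Filter,           // KernelHeight * KernelWidth taps, 4 floats each
    float* Output,
    size_t KernelHeight,
    size_t KernelWidth,
    size_t StrideWidthElements,
    size_t DilationWidthElements,
    size_t DilatedInputWidthElements,
    size_t TotalOutputCount,
    const float* Bias,             // 4 values, or nullptr for no bias
    bool ReluActivation)
{
    for (size_t output_idx = 0; output_idx < TotalOutputCount; output_idx++) {
        for (size_t lane = 0; lane < 4; lane++) {
            float acc = (Bias != nullptr) ? Bias[lane] : 0.0f;

            for (size_t kh = 0; kh < KernelHeight; kh++) {
                for (size_t kw = 0; kw < KernelWidth; kw++) {
                    const float* p = Input + output_idx * StrideWidthElements +
                        kh * DilatedInputWidthElements + kw * DilationWidthElements + lane;
                    const float* row = InputBase + kh * DilatedInputWidthElements;

                    // Same bounds test as the kernel: taps outside the valid row
                    // contribute zero (implicit padding).
                    float x = (p >= row && p < row + InputWidthElements) ? *p : 0.0f;

                    acc += x * Filter[(kh * KernelWidth + kw) * 4 + lane];
                }
            }

            Output[output_idx * 4 + lane] = ReluActivation ? std::max(acc, 0.0f) : acc;
        }
    }
}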